- """Small plotting-related utility functions."""
- import colorsys
- import os
- import numpy as np
- from scipy import stats
- import pandas as pd
- import matplotlib as mpl
- import matplotlib.colors as mplcol
- import matplotlib.pyplot as plt
- import warnings
- from urllib.request import urlopen, urlretrieve
- from http.client import HTTPException
- __all__ = ["desaturate", "saturate", "set_hls_values",
- "despine", "get_dataset_names", "get_data_home", "load_dataset"]


def remove_na(arr):
    """Helper method for removing NA values from array-like.

    Parameters
    ----------
    arr : array-like
        The array-like from which to remove NA values.

    Returns
    -------
    clean_arr : array-like
        The original array with NA values removed.

    """
    return arr[pd.notnull(arr)]


def sort_df(df, *args, **kwargs):
    """Wrapper to handle different pandas sorting API pre/post 0.17."""
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg)
    try:
        return df.sort_values(*args, **kwargs)
    except AttributeError:
        return df.sort(*args, **kwargs)


def ci_to_errsize(cis, heights):
    """Convert intervals to error arguments relative to plot heights.

    Parameters
    ----------
    cis : 2 x n sequence
        sequence of confidence interval limits
    heights : n sequence
        sequence of plot heights

    Returns
    -------
    errsize : 2 x n array
        sequence of error size relative to height values in correct
        format as argument for plt.bar

    """
    cis = np.atleast_2d(cis).reshape(2, -1)
    heights = np.atleast_1d(heights)
    errsize = []
    for i, (low, high) in enumerate(np.transpose(cis)):
        h = heights[i]
        elow = h - low
        ehigh = high - h
        errsize.append([elow, ehigh])

    errsize = np.asarray(errsize).T
    return errsize
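
# Example usage (illustrative; the heights and interval limits below are
# arbitrary values, not data shipped with the module):
#
#     heights = [3.0, 5.0]
#     cis = [[2.5, 4.2],            # lower interval limits
#            [3.4, 5.9]]            # upper interval limits
#     yerr = ci_to_errsize(cis, heights)
#     plt.bar([0, 1], heights, yerr=yerr)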


def pmf_hist(a, bins=10):
    """Return arguments to plt.bar for pmf-like histogram of an array.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    a : array-like
        array to make histogram of
    bins : int
        number of bins

    Returns
    -------
    x : array
        left x position of bars
    h : array
        height of bars
    w : float
        width of bars

    """
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg)
    n, x = np.histogram(a, bins)
    h = n / n.sum()
    w = x[1] - x[0]
    return x[:-1], h, w


def desaturate(color, prop):
    """Decrease the saturation channel of a color by some percent.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    prop : float
        saturation channel of color will be multiplied by this value

    Returns
    -------
    new_color : rgb tuple
        desaturated color code in RGB tuple representation

    """
    # Check inputs
    if not 0 <= prop <= 1:
        raise ValueError("prop must be between 0 and 1")

    # Get rgb tuple rep
    rgb = mplcol.colorConverter.to_rgb(color)

    # Convert to hls
    h, l, s = colorsys.rgb_to_hls(*rgb)

    # Desaturate the saturation channel
    s *= prop

    # Convert back to rgb
    new_color = colorsys.hls_to_rgb(h, l, s)

    return new_color


def saturate(color):
    """Return a fully saturated color with the same hue.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name

    Returns
    -------
    new_color : rgb tuple
        saturated color code in RGB tuple representation

    """
    return set_hls_values(color, s=1)


def set_hls_values(color, h=None, l=None, s=None):  # noqa
    """Independently manipulate the h, l, or s channels of a color.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    h, l, s : floats between 0 and 1, or None
        new values for each channel in hls space

    Returns
    -------
    new_color : rgb tuple
        new color code in RGB tuple representation

    """
    # Get an RGB tuple representation
    rgb = mplcol.colorConverter.to_rgb(color)
    vals = list(colorsys.rgb_to_hls(*rgb))
    for i, val in enumerate([h, l, s]):
        if val is not None:
            vals[i] = val

    rgb = colorsys.hls_to_rgb(*vals)
    return rgb
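
# Example usage (illustrative; the hex code below is an arbitrary color):
#
#     muted = desaturate("#4c72b0", 0.5)        # halve the saturation channel
#     vivid = saturate("#4c72b0")               # set saturation to 1 in HLS space
#     light = set_hls_values("#4c72b0", l=.8)   # change only the lightness channel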


def axlabel(xlabel, ylabel, **kwargs):
    """Grab current axis and label it."""
    ax = plt.gca()
    ax.set_xlabel(xlabel, **kwargs)
    ax.set_ylabel(ylabel, **kwargs)


def despine(fig=None, ax=None, top=True, right=True, left=False,
            bottom=False, offset=None, trim=False):
    """Remove the top and right spines from plot(s).

    Parameters
    ----------
    fig : matplotlib figure, optional
        Figure to despine all axes of, defaults to the current figure.
    ax : matplotlib axes, optional
        Specific axes object to despine.
    top, right, left, bottom : boolean, optional
        If True, remove that spine.
    offset : int or dict, optional
        Absolute distance, in points, spines should be moved away
        from the axes (negative values move spines inward). A single value
        applies to all spines; a dict can be used to set offset values per
        side.
    trim : bool, optional
        If True, limit spines to the smallest and largest major tick
        on each non-despined axis.

    Returns
    -------
    None

    """
    # Get references to the axes we want
    if fig is None and ax is None:
        axes = plt.gcf().axes
    elif fig is not None:
        axes = fig.axes
    elif ax is not None:
        axes = [ax]

    for ax_i in axes:
        for side in ["top", "right", "left", "bottom"]:
            # Toggle the spine objects
            is_visible = not locals()[side]
            ax_i.spines[side].set_visible(is_visible)
            if offset is not None and is_visible:
                try:
                    val = offset.get(side, 0)
                except AttributeError:
                    val = offset
                ax_i.spines[side].set_position(('outward', val))

        # Potentially move the ticks
        if left and not right:
            maj_on = any(
                t.tick1line.get_visible()
                for t in ax_i.yaxis.majorTicks
            )
            min_on = any(
                t.tick1line.get_visible()
                for t in ax_i.yaxis.minorTicks
            )
            ax_i.yaxis.set_ticks_position("right")
            for t in ax_i.yaxis.majorTicks:
                t.tick2line.set_visible(maj_on)
            for t in ax_i.yaxis.minorTicks:
                t.tick2line.set_visible(min_on)

        if bottom and not top:
            maj_on = any(
                t.tick1line.get_visible()
                for t in ax_i.xaxis.majorTicks
            )
            min_on = any(
                t.tick1line.get_visible()
                for t in ax_i.xaxis.minorTicks
            )
            ax_i.xaxis.set_ticks_position("top")
            for t in ax_i.xaxis.majorTicks:
                t.tick2line.set_visible(maj_on)
            for t in ax_i.xaxis.minorTicks:
                t.tick2line.set_visible(min_on)

        if trim:
            # Clip off the parts of the spines that extend past major ticks
            xticks = np.asarray(ax_i.get_xticks())
            if xticks.size:
                firsttick = np.compress(xticks >= min(ax_i.get_xlim()),
                                        xticks)[0]
                lasttick = np.compress(xticks <= max(ax_i.get_xlim()),
                                       xticks)[-1]
                ax_i.spines['bottom'].set_bounds(firsttick, lasttick)
                ax_i.spines['top'].set_bounds(firsttick, lasttick)
                newticks = xticks.compress(xticks <= lasttick)
                newticks = newticks.compress(newticks >= firsttick)
                ax_i.set_xticks(newticks)

            yticks = np.asarray(ax_i.get_yticks())
            if yticks.size:
                firsttick = np.compress(yticks >= min(ax_i.get_ylim()),
                                        yticks)[0]
                lasttick = np.compress(yticks <= max(ax_i.get_ylim()),
                                       yticks)[-1]
                ax_i.spines['left'].set_bounds(firsttick, lasttick)
                ax_i.spines['right'].set_bounds(firsttick, lasttick)
                newticks = yticks.compress(yticks <= lasttick)
                newticks = newticks.compress(newticks >= firsttick)
                ax_i.set_yticks(newticks)
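
# Example usage (illustrative):
#
#     fig, ax = plt.subplots()
#     ax.plot([0, 1, 2], [0, 1, 4])
#     despine(ax=ax)                         # remove the top and right spines
#     despine(ax=ax, offset=10, trim=True)   # offset and trim the remaining spines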


def _kde_support(data, bw, gridsize, cut, clip):
    """Establish support for a kernel density estimate."""
    support_min = max(data.min() - bw * cut, clip[0])
    support_max = min(data.max() + bw * cut, clip[1])
    return np.linspace(support_min, support_max, gridsize)


def percentiles(a, pcts, axis=None):
    """Like scoreatpercentile but can take and return array of percentiles.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    a : array
        data
    pcts : sequence of percentile values
        percentile or percentiles to find score at
    axis : int or None
        if not None, computes scores over this axis

    Returns
    -------
    scores : array
        array of scores at requested percentiles
        first dimension is length of object passed to ``pcts``

    """
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg)

    scores = []
    try:
        n = len(pcts)
    except TypeError:
        pcts = [pcts]
        n = 0
    for i, p in enumerate(pcts):
        if axis is None:
            score = stats.scoreatpercentile(a.ravel(), p)
        else:
            score = np.apply_along_axis(stats.scoreatpercentile, axis, a, p)
        scores.append(score)
    scores = np.asarray(scores)
    if not n:
        scores = scores.squeeze()
    return scores


def ci(a, which=95, axis=None):
    """Return a percentile range from an array of values."""
    p = 50 - which / 2, 50 + which / 2
    return np.percentile(a, p, axis)
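
# Example usage (illustrative; the sample below is randomly generated):
#
#     a = np.random.RandomState(0).normal(size=1000)
#     low, high = ci(a, which=95)   # the 2.5th and 97.5th percentiles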


def sig_stars(p):
    """Return an R-style significance string corresponding to p values.

    DEPRECATED: will be removed in a future version.

    """
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg)

    if p < 0.001:
        return "***"
    elif p < 0.01:
        return "**"
    elif p < 0.05:
        return "*"
    elif p < 0.1:
        return "."
    return ""


def iqr(a):
    """Calculate the IQR for an array of numbers."""
    a = np.asarray(a)
    q1 = stats.scoreatpercentile(a, 25)
    q3 = stats.scoreatpercentile(a, 75)
    return q3 - q1
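
# Example usage (illustrative; for a standard normal sample the IQR is ~1.35):
#
#     a = np.random.RandomState(0).normal(size=1000)
#     spread = iqr(a)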


def get_dataset_names():
    """Report available example datasets, useful for reporting issues."""
    # Delayed import so bs4 is only required if this function is actually used
    from bs4 import BeautifulSoup
    http = urlopen('https://github.com/mwaskom/seaborn-data/')
    gh_list = BeautifulSoup(http, "html.parser")

    return [l.text.replace('.csv', '')
            for l in gh_list.find_all("a", {"class": "js-navigation-open"})
            if l.text.endswith('.csv')]


def get_data_home(data_home=None):
    """Return a path to the cache directory for example datasets.

    This directory is then used by :func:`load_dataset`.

    If the ``data_home`` argument is not specified, it tries to read from the
    ``SEABORN_DATA`` environment variable and defaults to ``~/seaborn-data``.

    """
    if data_home is None:
        data_home = os.environ.get('SEABORN_DATA',
                                   os.path.join('~', 'seaborn-data'))
    data_home = os.path.expanduser(data_home)
    if not os.path.exists(data_home):
        os.makedirs(data_home)
    return data_home


def load_dataset(name, cache=True, data_home=None, **kws):
    """Load an example dataset from the online repository (requires internet).

    This function provides quick access to a small number of example datasets
    that are useful for documenting seaborn or generating reproducible examples
    for bug reports. It is not necessary for normal usage.

    Note that some of the datasets have a small amount of preprocessing applied
    to define a proper ordering for categorical variables.

    Use :func:`get_dataset_names` to see a list of available datasets.

    Parameters
    ----------
    name : str
        Name of the dataset (``{name}.csv`` on
        https://github.com/mwaskom/seaborn-data).
    cache : boolean, optional
        If True, try to load from the local cache first, and save to the cache
        if a download is required.
    data_home : string, optional
        The directory in which to cache data; see :func:`get_data_home`.
    kws : keys and values, optional
        Additional keyword arguments are passed through to
        :func:`pandas.read_csv`.

    Returns
    -------
    df : :class:`pandas.DataFrame`
        Tabular data, possibly with some preprocessing applied.

    """
    path = ("https://raw.githubusercontent.com/"
            "mwaskom/seaborn-data/master/{}.csv")
    full_path = path.format(name)

    if cache:
        cache_path = os.path.join(get_data_home(data_home),
                                  os.path.basename(full_path))
        if not os.path.exists(cache_path):
            urlretrieve(full_path, cache_path)
        full_path = cache_path

    df = pd.read_csv(full_path, **kws)
    if df.iloc[-1].isnull().all():
        df = df.iloc[:-1]

    # Set some columns as a categorical type with ordered levels

    if name == "tips":
        df["day"] = pd.Categorical(df["day"], ["Thur", "Fri", "Sat", "Sun"])
        df["sex"] = pd.Categorical(df["sex"], ["Male", "Female"])
        df["time"] = pd.Categorical(df["time"], ["Lunch", "Dinner"])
        df["smoker"] = pd.Categorical(df["smoker"], ["Yes", "No"])

    if name == "flights":
        df["month"] = pd.Categorical(df["month"], df.month.unique())

    if name == "exercise":
        df["time"] = pd.Categorical(df["time"], ["1 min", "15 min", "30 min"])
        df["kind"] = pd.Categorical(df["kind"], ["rest", "walking", "running"])
        df["diet"] = pd.Categorical(df["diet"], ["no fat", "low fat"])

    if name == "titanic":
        df["class"] = pd.Categorical(df["class"], ["First", "Second", "Third"])
        df["deck"] = pd.Categorical(df["deck"], list("ABCDEFG"))

    return df
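
# Example usage (illustrative; requires internet access on the first call,
# after which the file is read from the cache directory):
#
#     tips = load_dataset("tips")
#     tips.head()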


def axis_ticklabels_overlap(labels):
    """Return a boolean for whether the list of ticklabels have overlaps.

    Parameters
    ----------
    labels : list of matplotlib ticklabels

    Returns
    -------
    overlap : boolean
        True if any of the labels overlap.

    """
    if not labels:
        return False
    try:
        bboxes = [l.get_window_extent() for l in labels]
        overlaps = [b.count_overlaps(bboxes) for b in bboxes]
        return max(overlaps) > 1
    except RuntimeError:
        # Issue on macos backend raises an error in the above code
        return False


def axes_ticklabels_overlap(ax):
    """Return booleans for whether the x and y ticklabels on an Axes overlap.

    Parameters
    ----------
    ax : matplotlib Axes

    Returns
    -------
    x_overlap, y_overlap : booleans
        True when the labels on that axis overlap.

    """
    return (axis_ticklabels_overlap(ax.get_xticklabels()),
            axis_ticklabels_overlap(ax.get_yticklabels()))
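
# Example usage (illustrative; the figure must be drawn first so that label
# extents are available in display coordinates):
#
#     fig, ax = plt.subplots(figsize=(2, 2))
#     ax.set_xticks(np.linspace(0, 1, 20))
#     fig.canvas.draw()
#     x_overlap, y_overlap = axes_ticklabels_overlap(ax)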


def categorical_order(values, order=None):
    """Return a list of unique data values.

    Determine an ordered list of levels in ``values``.

    Parameters
    ----------
    values : list, array, Categorical, or Series
        Vector of "categorical" values
    order : list-like, optional
        Desired order of category levels to override the order determined
        from the ``values`` object.

    Returns
    -------
    order : list
        Ordered list of category levels not including null values.

    """
    if order is None:
        if hasattr(values, "categories"):
            order = values.categories
        else:
            try:
                order = values.cat.categories
            except (TypeError, AttributeError):

                try:
                    order = values.unique()
                except AttributeError:
                    order = pd.unique(values)

                try:
                    np.asarray(values).astype(float)
                    order = np.sort(order)
                except (ValueError, TypeError):
                    order = order

        order = filter(pd.notnull, order)
    return list(order)
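
# Example usage (illustrative): string-like data keep their order of
# appearance, numeric-looking data are sorted, and an explicit ``order`` wins.
#
#     categorical_order(["b", "a", "c", "a"])                    # -> ["b", "a", "c"]
#     categorical_order([3, 1, 2])                               # -> [1, 2, 3]
#     categorical_order(["b", "a", "c"], order=["c", "b", "a"])  # -> ["c", "b", "a"]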


def locator_to_legend_entries(locator, limits, dtype):
    """Return levels and formatted levels for brief numeric legends."""
    raw_levels = locator.tick_values(*limits).astype(dtype)

    class dummy_axis:
        def get_view_interval(self):
            return limits

    if isinstance(locator, mpl.ticker.LogLocator):
        formatter = mpl.ticker.LogFormatter()
    else:
        formatter = mpl.ticker.ScalarFormatter()
    formatter.axis = dummy_axis()

    # TODO: The following two lines should be replaced
    # once pinned matplotlib>=3.1.0 with:
    # formatted_levels = formatter.format_ticks(raw_levels)
    formatter.set_locs(raw_levels)
    formatted_levels = [formatter(x) for x in raw_levels]

    return raw_levels, formatted_levels


def get_color_cycle():
    """Return the list of colors in the current matplotlib color cycle.

    Parameters
    ----------
    None

    Returns
    -------
    colors : list
        List of matplotlib colors in the current cycle, or dark gray if
        the current color cycle is empty.

    """
    cycler = mpl.rcParams['axes.prop_cycle']
    return cycler.by_key()['color'] if 'color' in cycler.keys else [".15"]
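
# Example usage (illustrative; the hex value shown assumes matplotlib's default
# rcParams, where the cycle is the ten "tab10" colors):
#
#     colors = get_color_cycle()
#     first = colors[0]   # "#1f77b4" under the default style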


def relative_luminance(color):
    """Calculate the relative luminance of a color according to W3C standards.

    Parameters
    ----------
    color : matplotlib color or sequence of matplotlib colors
        Hex code, rgb-tuple, or html color name.

    Returns
    -------
    luminance : float(s) between 0 and 1

    """
    rgb = mpl.colors.colorConverter.to_rgba_array(color)[:, :3]
    rgb = np.where(rgb <= .03928, rgb / 12.92, ((rgb + .055) / 1.055) ** 2.4)
    lum = rgb.dot([.2126, .7152, .0722])
    try:
        return lum.item()
    except ValueError:
        return lum
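
# Example usage (illustrative; the 0.408 cutoff is an arbitrary threshold chosen
# here for demonstration, e.g. to pick readable text on a colored background):
#
#     lum = relative_luminance("#333333")
#     text_color = "w" if lum < .408 else ".15"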


def to_utf8(obj):
    """Return a string representing a Python object.

    Strings (i.e. type ``str``) are returned unchanged.

    Byte strings (i.e. type ``bytes``) are returned as UTF-8-decoded strings.

    For other objects, the method ``__str__()`` is called, and the result is
    returned as a string.

    Parameters
    ----------
    obj : object
        Any Python object

    Returns
    -------
    s : str
        UTF-8-decoded string representation of ``obj``

    """
    if isinstance(obj, str):
        return obj
    try:
        return obj.decode(encoding="utf-8")
    except AttributeError:  # obj is not bytes-like
        return str(obj)
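
# Example usage (illustrative):
#
#     to_utf8(b"caf\xc3\xa9")   # -> "café"  (bytes are decoded as UTF-8)
#     to_utf8(3.14)             # -> "3.14"  (other objects go through str())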


def _network(t=None, url='https://google.com'):
    """
    Decorator that will skip a test if `url` is unreachable.

    Parameters
    ----------
    t : function, optional
    url : str, optional

    """
    import nose

    if t is None:
        return lambda x: _network(x, url=url)

    def wrapper(*args, **kwargs):
        # attempt to connect
        try:
            f = urlopen(url)
        except (IOError, HTTPException):
            raise nose.SkipTest()
        else:
            f.close()
            return t(*args, **kwargs)

    return wrapper