category.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. """
  2. Plotting of string "category" data: ``plot(['d', 'f', 'a'], [1, 2, 3])`` will
  3. plot three points with x-axis values of 'd', 'f', 'a'.
  4. See :doc:`/gallery/lines_bars_and_markers/categorical_variables` for an
  5. example.
  6. The module uses Matplotlib's `matplotlib.units` mechanism to convert from
  7. strings to integers and provides a tick locator, a tick formatter, and the
  8. `.UnitData` class that creates and stores the string-to-integer mapping.
  9. """
  10. from collections import OrderedDict
  11. import dateutil.parser
  12. import itertools
  13. import logging
  14. import numpy as np
  15. from matplotlib import cbook, ticker, units
  16. _log = logging.getLogger(__name__)
  17. class StrCategoryConverter(units.ConversionInterface):
  18. @staticmethod
  19. def convert(value, unit, axis):
  20. """
  21. Convert strings in *value* to floats using mapping information stored
  22. in the *unit* object.
  23. Parameters
  24. ----------
  25. value : str or iterable
  26. Value or list of values to be converted.
  27. unit : `.UnitData`
  28. An object mapping strings to integers.
  29. axis : `~matplotlib.axis.Axis`
  30. The axis on which the converted value is plotted.
  31. .. note:: *axis* is unused.
  32. Returns
  33. -------
  34. mapped_value : float or ndarray[float]
  35. """
  36. if unit is None:
  37. raise ValueError(
  38. 'Missing category information for StrCategoryConverter; '
  39. 'this might be caused by unintendedly mixing categorical and '
  40. 'numeric data')
  41. # dtype = object preserves numerical pass throughs
  42. values = np.atleast_1d(np.array(value, dtype=object))
  43. # pass through sequence of non binary numbers
  44. if all(units.ConversionInterface.is_numlike(v)
  45. and not isinstance(v, (str, bytes))
  46. for v in values):
  47. return np.asarray(values, dtype=float)
  48. # force an update so it also does type checking
  49. unit.update(values)
  50. return np.vectorize(unit._mapping.__getitem__, otypes=[float])(values)
  51. @staticmethod
  52. def axisinfo(unit, axis):
  53. """
  54. Set the default axis ticks and labels.
  55. Parameters
  56. ----------
  57. unit : `.UnitData`
  58. object string unit information for value
  59. axis : `~matplotlib.axis.Axis`
  60. axis for which information is being set
  61. Returns
  62. -------
  63. axisinfo : `~matplotlib.units.AxisInfo`
  64. Information to support default tick labeling
  65. .. note: axis is not used
  66. """
  67. # locator and formatter take mapping dict because
  68. # args need to be pass by reference for updates
  69. majloc = StrCategoryLocator(unit._mapping)
  70. majfmt = StrCategoryFormatter(unit._mapping)
  71. return units.AxisInfo(majloc=majloc, majfmt=majfmt)
  72. @staticmethod
  73. def default_units(data, axis):
  74. """
  75. Set and update the `~matplotlib.axis.Axis` units.
  76. Parameters
  77. ----------
  78. data : str or iterable of str
  79. axis : `~matplotlib.axis.Axis`
  80. axis on which the data is plotted
  81. Returns
  82. -------
  83. class : `.UnitData`
  84. object storing string to integer mapping
  85. """
  86. # the conversion call stack is default_units -> axis_info -> convert
  87. if axis.units is None:
  88. axis.set_units(UnitData(data))
  89. else:
  90. axis.units.update(data)
  91. return axis.units
  92. class StrCategoryLocator(ticker.Locator):
  93. """Tick at every integer mapping of the string data."""
  94. def __init__(self, units_mapping):
  95. """
  96. Parameters
  97. -----------
  98. units_mapping : Dict[str, int]
  99. """
  100. self._units = units_mapping
  101. def __call__(self):
  102. return list(self._units.values())
  103. def tick_values(self, vmin, vmax):
  104. return self()
  105. class StrCategoryFormatter(ticker.Formatter):
  106. """String representation of the data at every tick."""
  107. def __init__(self, units_mapping):
  108. """
  109. Parameters
  110. ----------
  111. units_mapping : Dict[Str, int]
  112. """
  113. self._units = units_mapping
  114. def __call__(self, x, pos=None):
  115. """
  116. Return the category label string for tick val *x*.
  117. The position *pos* is ignored.
  118. """
  119. return self.format_ticks([x])[0]
  120. def format_ticks(self, values):
  121. r_mapping = {v: self._text(k) for k, v in self._units.items()}
  122. return [r_mapping.get(round(val), '') for val in values]
  123. @staticmethod
  124. def _text(value):
  125. """Convert text values into utf-8 or ascii strings."""
  126. if isinstance(value, bytes):
  127. value = value.decode(encoding='utf-8')
  128. elif not isinstance(value, str):
  129. value = str(value)
  130. return value
  131. class UnitData:
  132. def __init__(self, data=None):
  133. """
  134. Create mapping between unique categorical values and integer ids.
  135. Parameters
  136. ----------
  137. data : iterable
  138. sequence of string values
  139. """
  140. self._mapping = OrderedDict()
  141. self._counter = itertools.count()
  142. if data is not None:
  143. self.update(data)
  144. @staticmethod
  145. def _str_is_convertible(val):
  146. """
  147. Helper method to check whether a string can be parsed as float or date.
  148. """
  149. try:
  150. float(val)
  151. except ValueError:
  152. try:
  153. dateutil.parser.parse(val)
  154. except (ValueError, TypeError):
  155. # TypeError if dateutil >= 2.8.1 else ValueError
  156. return False
  157. return True
  158. def update(self, data):
  159. """
  160. Map new values to integer identifiers.
  161. Parameters
  162. ----------
  163. data : iterable
  164. sequence of string values
  165. Raises
  166. ------
  167. TypeError
  168. If the value in data is not a string, unicode, bytes type
  169. """
  170. data = np.atleast_1d(np.array(data, dtype=object))
  171. # check if convertible to number:
  172. convertible = True
  173. for val in OrderedDict.fromkeys(data):
  174. # OrderedDict just iterates over unique values in data.
  175. cbook._check_isinstance((str, bytes), value=val)
  176. if convertible:
  177. # this will only be called so long as convertible is True.
  178. convertible = self._str_is_convertible(val)
  179. if val not in self._mapping:
  180. self._mapping[val] = next(self._counter)
  181. if convertible:
  182. _log.info('Using categorical units to plot a list of strings '
  183. 'that are all parsable as floats or dates. If these '
  184. 'strings should be plotted as numbers, cast to the '
  185. 'appropriate data type before plotting.')
  186. # Register the converter with Matplotlib's unit framework
  187. units.registry[str] = StrCategoryConverter()
  188. units.registry[np.str_] = StrCategoryConverter()
  189. units.registry[bytes] = StrCategoryConverter()
  190. units.registry[np.bytes_] = StrCategoryConverter()