figureoptions.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. # Copyright © 2009 Pierre Raybaut
  2. # Licensed under the terms of the MIT License
  3. # see the Matplotlib licenses directory for a copy of the license
  4. """Module that provides a GUI-based editor for matplotlib's figure options."""
  5. import re
  6. import matplotlib
  7. from matplotlib import cbook, cm, colors as mcolors, markers, image as mimage
  8. from matplotlib.backends.qt_compat import QtGui
  9. from matplotlib.backends.qt_editor import _formlayout
  10. LINESTYLES = {'-': 'Solid',
  11. '--': 'Dashed',
  12. '-.': 'DashDot',
  13. ':': 'Dotted',
  14. 'None': 'None',
  15. }
  16. DRAWSTYLES = {
  17. 'default': 'Default',
  18. 'steps-pre': 'Steps (Pre)', 'steps': 'Steps (Pre)',
  19. 'steps-mid': 'Steps (Mid)',
  20. 'steps-post': 'Steps (Post)'}
  21. MARKERS = markers.MarkerStyle.markers
  22. def figure_edit(axes, parent=None):
  23. """Edit matplotlib figure options"""
  24. sep = (None, None) # separator
  25. # Get / General
  26. # Cast to builtin floats as they have nicer reprs.
  27. xmin, xmax = map(float, axes.get_xlim())
  28. ymin, ymax = map(float, axes.get_ylim())
  29. general = [('Title', axes.get_title()),
  30. sep,
  31. (None, "<b>X-Axis</b>"),
  32. ('Left', xmin), ('Right', xmax),
  33. ('Label', axes.get_xlabel()),
  34. ('Scale', [axes.get_xscale(), 'linear', 'log', 'logit']),
  35. sep,
  36. (None, "<b>Y-Axis</b>"),
  37. ('Bottom', ymin), ('Top', ymax),
  38. ('Label', axes.get_ylabel()),
  39. ('Scale', [axes.get_yscale(), 'linear', 'log', 'logit']),
  40. sep,
  41. ('(Re-)Generate automatic legend', False),
  42. ]
  43. # Save the unit data
  44. xconverter = axes.xaxis.converter
  45. yconverter = axes.yaxis.converter
  46. xunits = axes.xaxis.get_units()
  47. yunits = axes.yaxis.get_units()
  48. # Sorting for default labels (_lineXXX, _imageXXX).
  49. def cmp_key(label):
  50. match = re.match(r"(_line|_image)(\d+)", label)
  51. if match:
  52. return match.group(1), int(match.group(2))
  53. else:
  54. return label, 0
  55. # Get / Curves
  56. linedict = {}
  57. for line in axes.get_lines():
  58. label = line.get_label()
  59. if label == '_nolegend_':
  60. continue
  61. linedict[label] = line
  62. curves = []
  63. def prepare_data(d, init):
  64. """Prepare entry for FormLayout.
  65. *d* is a mapping of shorthands to style names (a single style may
  66. have multiple shorthands, in particular the shorthands `None`,
  67. `"None"`, `"none"` and `""` are synonyms); *init* is one shorthand
  68. of the initial style.
  69. This function returns an list suitable for initializing a
  70. FormLayout combobox, namely `[initial_name, (shorthand,
  71. style_name), (shorthand, style_name), ...]`.
  72. """
  73. if init not in d:
  74. d = {**d, init: str(init)}
  75. # Drop duplicate shorthands from dict (by overwriting them during
  76. # the dict comprehension).
  77. name2short = {name: short for short, name in d.items()}
  78. # Convert back to {shorthand: name}.
  79. short2name = {short: name for name, short in name2short.items()}
  80. # Find the kept shorthand for the style specified by init.
  81. canonical_init = name2short[d[init]]
  82. # Sort by representation and prepend the initial value.
  83. return ([canonical_init] +
  84. sorted(short2name.items(),
  85. key=lambda short_and_name: short_and_name[1]))
  86. curvelabels = sorted(linedict, key=cmp_key)
  87. for label in curvelabels:
  88. line = linedict[label]
  89. color = mcolors.to_hex(
  90. mcolors.to_rgba(line.get_color(), line.get_alpha()),
  91. keep_alpha=True)
  92. ec = mcolors.to_hex(
  93. mcolors.to_rgba(line.get_markeredgecolor(), line.get_alpha()),
  94. keep_alpha=True)
  95. fc = mcolors.to_hex(
  96. mcolors.to_rgba(line.get_markerfacecolor(), line.get_alpha()),
  97. keep_alpha=True)
  98. curvedata = [
  99. ('Label', label),
  100. sep,
  101. (None, '<b>Line</b>'),
  102. ('Line style', prepare_data(LINESTYLES, line.get_linestyle())),
  103. ('Draw style', prepare_data(DRAWSTYLES, line.get_drawstyle())),
  104. ('Width', line.get_linewidth()),
  105. ('Color (RGBA)', color),
  106. sep,
  107. (None, '<b>Marker</b>'),
  108. ('Style', prepare_data(MARKERS, line.get_marker())),
  109. ('Size', line.get_markersize()),
  110. ('Face color (RGBA)', fc),
  111. ('Edge color (RGBA)', ec)]
  112. curves.append([curvedata, label, ""])
  113. # Is there a curve displayed?
  114. has_curve = bool(curves)
  115. # Get ScalarMappables.
  116. mappabledict = {}
  117. for mappable in [*axes.images, *axes.collections]:
  118. label = mappable.get_label()
  119. if label == '_nolegend_' or mappable.get_array() is None:
  120. continue
  121. mappabledict[label] = mappable
  122. mappablelabels = sorted(mappabledict, key=cmp_key)
  123. mappables = []
  124. cmaps = [(cmap, name) for name, cmap in sorted(cm.cmap_d.items())]
  125. for label in mappablelabels:
  126. mappable = mappabledict[label]
  127. cmap = mappable.get_cmap()
  128. if cmap not in cm.cmap_d.values():
  129. cmaps = [(cmap, cmap.name), *cmaps]
  130. low, high = mappable.get_clim()
  131. mappabledata = [
  132. ('Label', label),
  133. ('Colormap', [cmap.name] + cmaps),
  134. ('Min. value', low),
  135. ('Max. value', high),
  136. ]
  137. if hasattr(mappable, "get_interpolation"): # Images.
  138. interpolations = [
  139. (name, name) for name in sorted(mimage.interpolations_names)]
  140. mappabledata.append((
  141. 'Interpolation',
  142. [mappable.get_interpolation(), *interpolations]))
  143. mappables.append([mappabledata, label, ""])
  144. # Is there a scalarmappable displayed?
  145. has_sm = bool(mappables)
  146. datalist = [(general, "Axes", "")]
  147. if curves:
  148. datalist.append((curves, "Curves", ""))
  149. if mappables:
  150. datalist.append((mappables, "Images, etc.", ""))
  151. def apply_callback(data):
  152. """This function will be called to apply changes"""
  153. orig_xlim = axes.get_xlim()
  154. orig_ylim = axes.get_ylim()
  155. general = data.pop(0)
  156. curves = data.pop(0) if has_curve else []
  157. mappables = data.pop(0) if has_sm else []
  158. if data:
  159. raise ValueError("Unexpected field")
  160. # Set / General
  161. (title, xmin, xmax, xlabel, xscale, ymin, ymax, ylabel, yscale,
  162. generate_legend) = general
  163. if axes.get_xscale() != xscale:
  164. axes.set_xscale(xscale)
  165. if axes.get_yscale() != yscale:
  166. axes.set_yscale(yscale)
  167. axes.set_title(title)
  168. axes.set_xlim(xmin, xmax)
  169. axes.set_xlabel(xlabel)
  170. axes.set_ylim(ymin, ymax)
  171. axes.set_ylabel(ylabel)
  172. # Restore the unit data
  173. axes.xaxis.converter = xconverter
  174. axes.yaxis.converter = yconverter
  175. axes.xaxis.set_units(xunits)
  176. axes.yaxis.set_units(yunits)
  177. axes.xaxis._update_axisinfo()
  178. axes.yaxis._update_axisinfo()
  179. # Set / Curves
  180. for index, curve in enumerate(curves):
  181. line = linedict[curvelabels[index]]
  182. (label, linestyle, drawstyle, linewidth, color, marker, markersize,
  183. markerfacecolor, markeredgecolor) = curve
  184. line.set_label(label)
  185. line.set_linestyle(linestyle)
  186. line.set_drawstyle(drawstyle)
  187. line.set_linewidth(linewidth)
  188. rgba = mcolors.to_rgba(color)
  189. line.set_alpha(None)
  190. line.set_color(rgba)
  191. if marker != 'none':
  192. line.set_marker(marker)
  193. line.set_markersize(markersize)
  194. line.set_markerfacecolor(markerfacecolor)
  195. line.set_markeredgecolor(markeredgecolor)
  196. # Set ScalarMappables.
  197. for index, mappable_settings in enumerate(mappables):
  198. mappable = mappabledict[mappablelabels[index]]
  199. if len(mappable_settings) == 5:
  200. label, cmap, low, high, interpolation = mappable_settings
  201. mappable.set_interpolation(interpolation)
  202. elif len(mappable_settings) == 4:
  203. label, cmap, low, high = mappable_settings
  204. mappable.set_label(label)
  205. mappable.set_cmap(cm.get_cmap(cmap))
  206. mappable.set_clim(*sorted([low, high]))
  207. # re-generate legend, if checkbox is checked
  208. if generate_legend:
  209. draggable = None
  210. ncol = 1
  211. if axes.legend_ is not None:
  212. old_legend = axes.get_legend()
  213. draggable = old_legend._draggable is not None
  214. ncol = old_legend._ncol
  215. new_legend = axes.legend(ncol=ncol)
  216. if new_legend:
  217. new_legend.set_draggable(draggable)
  218. # Redraw
  219. figure = axes.get_figure()
  220. figure.canvas.draw()
  221. if not (axes.get_xlim() == orig_xlim and axes.get_ylim() == orig_ylim):
  222. figure.canvas.toolbar.push_current()
  223. data = _formlayout.fedit(
  224. datalist, title="Figure options", parent=parent,
  225. icon=QtGui.QIcon(
  226. str(cbook._get_data_path('images', 'qt4_editor_options.svg'))),
  227. apply=apply_callback)
  228. if data is not None:
  229. apply_callback(data)