# csvs.py

  1. """
  2. Module for formatting output data into CSV files.
  3. """
  4. import csv as csvlib
  5. from io import StringIO
  6. import os
  7. from typing import Hashable, List, Mapping, Optional, Sequence, Union
  8. import warnings
  9. from zipfile import ZipFile
  10. import numpy as np
  11. from pandas._libs import writers as libwriters
  12. from pandas._typing import FilePathOrBuffer
  13. from pandas.core.dtypes.generic import (
  14. ABCDatetimeIndex,
  15. ABCIndexClass,
  16. ABCMultiIndex,
  17. ABCPeriodIndex,
  18. )
  19. from pandas.core.dtypes.missing import notna
  20. from pandas.io.common import (
  21. get_compression_method,
  22. get_filepath_or_buffer,
  23. get_handle,
  24. infer_compression,
  25. )


class CSVFormatter:
    def __init__(
        self,
        obj,
        path_or_buf: Optional[FilePathOrBuffer[str]] = None,
        sep: str = ",",
        na_rep: str = "",
        float_format: Optional[str] = None,
        cols=None,
        header: Union[bool, Sequence[Hashable]] = True,
        index: bool = True,
        index_label: Optional[Union[bool, Hashable, Sequence[Hashable]]] = None,
        mode: str = "w",
        encoding: Optional[str] = None,
        compression: Union[str, Mapping[str, str], None] = "infer",
        quoting: Optional[int] = None,
        line_terminator="\n",
        chunksize: Optional[int] = None,
        quotechar='"',
        date_format: Optional[str] = None,
        doublequote: bool = True,
        escapechar: Optional[str] = None,
        decimal=".",
    ):
        self.obj = obj

        if path_or_buf is None:
            path_or_buf = StringIO()

        # Extract compression mode as given, if dict
        compression, self.compression_args = get_compression_method(compression)

        self.path_or_buf, _, _, _ = get_filepath_or_buffer(
            path_or_buf, encoding=encoding, compression=compression, mode=mode
        )
        self.sep = sep
        self.na_rep = na_rep
        self.float_format = float_format
        self.decimal = decimal

        self.header = header
        self.index = index
        self.index_label = index_label
        self.mode = mode
        if encoding is None:
            encoding = "utf-8"
        self.encoding = encoding
        self.compression = infer_compression(self.path_or_buf, compression)

        if quoting is None:
            quoting = csvlib.QUOTE_MINIMAL
        self.quoting = quoting

        if quoting == csvlib.QUOTE_NONE:
            # prevents crash in _csv
            quotechar = None
        self.quotechar = quotechar

        self.doublequote = doublequote
        self.escapechar = escapechar

        self.line_terminator = line_terminator or os.linesep
        self.date_format = date_format

        self.has_mi_columns = isinstance(obj.columns, ABCMultiIndex)

        # validate mi options
        if self.has_mi_columns:
            if cols is not None:
                raise TypeError("cannot specify cols with a MultiIndex on the columns")

        if cols is not None:
            if isinstance(cols, ABCIndexClass):
                cols = cols.to_native_types(
                    na_rep=na_rep,
                    float_format=float_format,
                    date_format=date_format,
                    quoting=self.quoting,
                )
            else:
                cols = list(cols)
            self.obj = self.obj.loc[:, cols]

        # update columns to include possible multiplicity of dupes
        # and make sure cols is just a list of labels
        cols = self.obj.columns
        if isinstance(cols, ABCIndexClass):
            cols = cols.to_native_types(
                na_rep=na_rep,
                float_format=float_format,
                date_format=date_format,
                quoting=self.quoting,
            )
        else:
            cols = list(cols)

        # save the normalized column labels
        self.cols = cols

        # preallocate data 2d list
        self.blocks = self.obj._data.blocks
        ncols = sum(b.shape[0] for b in self.blocks)
        self.data = [None] * ncols

        if chunksize is None:
            chunksize = (100000 // (len(self.cols) or 1)) or 1
        self.chunksize = int(chunksize)
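        # Worked example of the default heuristic above (editor's note, not
        # from the original source): a frame with 5 columns gets
        # 100000 // 5 = 20000 rows per chunk, while a frame with more than
        # 100000 columns falls back to the minimum of one row per chunk via
        # the trailing `or 1`.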

        self.data_index = obj.index
        if (
            isinstance(self.data_index, (ABCDatetimeIndex, ABCPeriodIndex))
            and date_format is not None
        ):
            from pandas import Index

            self.data_index = Index(
                [x.strftime(date_format) if notna(x) else "" for x in self.data_index]
            )

        self.nlevels = getattr(self.data_index, "nlevels", 1)
        if not index:
            self.nlevels = 0

    def save(self) -> None:
        """
        Create the writer & save.
        """
        # GH21227 internal compression is not used when file-like passed.
        if self.compression and hasattr(self.path_or_buf, "write"):
            warnings.warn(
                "compression has no effect when passing file-like object as input.",
                RuntimeWarning,
                stacklevel=2,
            )

        # take the zip path when writing into a ZipFile handle or to a
        # zip-compressed target
        is_zip = isinstance(self.path_or_buf, ZipFile) or (
            not hasattr(self.path_or_buf, "write") and self.compression == "zip"
        )

        if is_zip:
            # ZipFile does not support writing strings directly into an
            # archive, so collect the csv output in a string buffer and dump
            # it into the compressed file handle afterwards. GH21241, GH21118
            f = StringIO()
            close = False
        elif hasattr(self.path_or_buf, "write"):
            f = self.path_or_buf
            close = False
        else:
            f, handles = get_handle(
                self.path_or_buf,
                self.mode,
                encoding=self.encoding,
                compression=dict(self.compression_args, method=self.compression),
            )
            close = True

        try:
            # Note: self.encoding is irrelevant here
            self.writer = csvlib.writer(
                f,
                lineterminator=self.line_terminator,
                delimiter=self.sep,
                quoting=self.quoting,
                doublequote=self.doublequote,
                escapechar=self.escapechar,
                quotechar=self.quotechar,
            )
            self._save()
        finally:
            if is_zip:
                # GH17778 handles zip compression separately.
                buf = f.getvalue()
                if hasattr(self.path_or_buf, "write"):
                    self.path_or_buf.write(buf)
                else:
                    compression = dict(self.compression_args, method=self.compression)
                    f, handles = get_handle(
                        self.path_or_buf,
                        self.mode,
                        encoding=self.encoding,
                        compression=compression,
                    )
                    f.write(buf)
                    close = True
            if close:
                f.close()
                for _fh in handles:
                    _fh.close()

    def _save_header(self):
        """
        Write the header row(s): index labels plus column labels, or one row
        per column level when the columns are a MultiIndex.
        """
        writer = self.writer
        obj = self.obj
        index_label = self.index_label
        cols = self.cols
        has_mi_columns = self.has_mi_columns
        header = self.header
        encoded_labels: List[str] = []

        has_aliases = isinstance(header, (tuple, list, np.ndarray, ABCIndexClass))
        if not (has_aliases or self.header):
            return

        if has_aliases:
            if len(header) != len(cols):
                raise ValueError(
                    f"Writing {len(cols)} cols but got {len(header)} aliases"
                )
            else:
                write_cols = header
        else:
            write_cols = cols

        if self.index:
            # should write something for index label
            if index_label is not False:
                if index_label is None:
                    if isinstance(obj.index, ABCMultiIndex):
                        index_label = []
                        for i, name in enumerate(obj.index.names):
                            if name is None:
                                name = ""
                            index_label.append(name)
                    else:
                        index_label = obj.index.name
                        if index_label is None:
                            index_label = [""]
                        else:
                            index_label = [index_label]
                elif not isinstance(
                    index_label, (list, tuple, np.ndarray, ABCIndexClass)
                ):
                    # given a string for a DF with Index
                    index_label = [index_label]

                encoded_labels = list(index_label)
            else:
                encoded_labels = []

        if not has_mi_columns or has_aliases:
            encoded_labels += list(write_cols)
            writer.writerow(encoded_labels)
        else:
            # write out the mi
            columns = obj.columns

            # write out the names for each level, then ALL of the values for
            # each level
            for i in range(columns.nlevels):
                # we need at least 1 index column to write our col names
                col_line = []
                if self.index:
                    # name is the first column
                    col_line.append(columns.names[i])

                    if isinstance(index_label, list) and len(index_label) > 1:
                        col_line.extend([""] * (len(index_label) - 1))

                col_line.extend(columns._get_level_values(i))
                writer.writerow(col_line)

            # Write out the index line if it's not empty.
            # Otherwise, we will print out an extraneous
            # blank line between the mi and the data rows.
            if encoded_labels and set(encoded_labels) != {""}:
                encoded_labels.extend([""] * len(columns))
                writer.writerow(encoded_labels)

    def _save(self) -> None:
        """
        Write the header, then the body in chunks of ``chunksize`` rows.
        """
        self._save_header()

        nrows = len(self.data_index)

        # write in chunks of `chunksize` rows
        chunksize = self.chunksize
        chunks = int(nrows / chunksize) + 1
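        # Editor's note (not in the original source): e.g. nrows=25000 with
        # chunksize=10000 gives chunks=3, covering rows [0, 10000),
        # [10000, 20000) and [20000, 25000). When nrows divides evenly, the
        # `+ 1` yields one empty tail chunk, which the start_i >= end_i
        # check below skips.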

        for i in range(chunks):
            start_i = i * chunksize
            end_i = min((i + 1) * chunksize, nrows)
            if start_i >= end_i:
                break

            self._save_chunk(start_i, end_i)

    def _save_chunk(self, start_i: int, end_i: int) -> None:
        """
        Convert rows ``start_i:end_i`` of each block and of the index to
        native (string) types, then write them out.
        """
        data_index = self.data_index

        # create the data for a chunk
        slicer = slice(start_i, end_i)
        for i in range(len(self.blocks)):
            b = self.blocks[i]
            d = b.to_native_types(
                slicer=slicer,
                na_rep=self.na_rep,
                float_format=self.float_format,
                decimal=self.decimal,
                date_format=self.date_format,
                quoting=self.quoting,
            )

            for col_loc, col in zip(b.mgr_locs, d):
                # self.data is a preallocated list
                self.data[col_loc] = col

        ix = data_index.to_native_types(
            slicer=slicer,
            na_rep=self.na_rep,
            float_format=self.float_format,
            decimal=self.decimal,
            date_format=self.date_format,
            quoting=self.quoting,
        )

        libwriters.write_csv_rows(self.data, ix, self.nlevels, self.cols, self.writer)
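

# ---------------------------------------------------------------------------
# Minimal usage sketch (editor's addition, not part of the original module):
# DataFrame.to_csv drives CSVFormatter in essentially this way -- build the
# formatter around a frame and a target, then call save(). The frame and the
# output filenames below are illustrative assumptions.
if __name__ == "__main__":
    import pandas as pd

    df = pd.DataFrame({"a": [1, 2], "b": [3.0, np.nan]})

    # plain csv with a custom NA representation
    CSVFormatter(df, "example_out.csv", na_rep="NA").save()

    # a .zip target is inferred as zip compression, exercising the
    # StringIO buffering path in save() above
    CSVFormatter(df, "example_out.csv.zip").save()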