# einsumfunc.py
"""
Implementation of optimized einsum.
"""
from __future__ import division, absolute_import, print_function

import itertools

from numpy.compat import basestring
from numpy.core.multiarray import c_einsum
from numpy.core.numeric import asanyarray, tensordot
from numpy.core.overrides import array_function_dispatch

__all__ = ['einsum', 'einsum_path']

# The 52 single-letter labels that may appear as einsum subscripts.
einsum_symbols = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
# Set form of the above for O(1) membership tests during parsing.
einsum_symbols_set = set(einsum_symbols)
  13. def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
  14. """
  15. Computes the number of FLOPS in the contraction.
  16. Parameters
  17. ----------
  18. idx_contraction : iterable
  19. The indices involved in the contraction
  20. inner : bool
  21. Does this contraction require an inner product?
  22. num_terms : int
  23. The number of terms in a contraction
  24. size_dictionary : dict
  25. The size of each of the indices in idx_contraction
  26. Returns
  27. -------
  28. flop_count : int
  29. The total number of FLOPS required for the contraction.
  30. Examples
  31. --------
  32. >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
  33. 30
  34. >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
  35. 60
  36. """
  37. overall_size = _compute_size_by_dict(idx_contraction, size_dictionary)
  38. op_factor = max(1, num_terms - 1)
  39. if inner:
  40. op_factor += 1
  41. return overall_size * op_factor
  42. def _compute_size_by_dict(indices, idx_dict):
  43. """
  44. Computes the product of the elements in indices based on the dictionary
  45. idx_dict.
  46. Parameters
  47. ----------
  48. indices : iterable
  49. Indices to base the product on.
  50. idx_dict : dictionary
  51. Dictionary of index sizes
  52. Returns
  53. -------
  54. ret : int
  55. The resulting product.
  56. Examples
  57. --------
  58. >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
  59. 90
  60. """
  61. ret = 1
  62. for i in indices:
  63. ret *= idx_dict[i]
  64. return ret
  65. def _find_contraction(positions, input_sets, output_set):
  66. """
  67. Finds the contraction for a given set of input and output sets.
  68. Parameters
  69. ----------
  70. positions : iterable
  71. Integer positions of terms used in the contraction.
  72. input_sets : list
  73. List of sets that represent the lhs side of the einsum subscript
  74. output_set : set
  75. Set that represents the rhs side of the overall einsum subscript
  76. Returns
  77. -------
  78. new_result : set
  79. The indices of the resulting contraction
  80. remaining : list
  81. List of sets that have not been contracted, the new set is appended to
  82. the end of this list
  83. idx_removed : set
  84. Indices removed from the entire contraction
  85. idx_contraction : set
  86. The indices used in the current contraction
  87. Examples
  88. --------
  89. # A simple dot product test case
  90. >>> pos = (0, 1)
  91. >>> isets = [set('ab'), set('bc')]
  92. >>> oset = set('ac')
  93. >>> _find_contraction(pos, isets, oset)
  94. ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
  95. # A more complex case with additional terms in the contraction
  96. >>> pos = (0, 2)
  97. >>> isets = [set('abd'), set('ac'), set('bdc')]
  98. >>> oset = set('ac')
  99. >>> _find_contraction(pos, isets, oset)
  100. ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
  101. """
  102. idx_contract = set()
  103. idx_remain = output_set.copy()
  104. remaining = []
  105. for ind, value in enumerate(input_sets):
  106. if ind in positions:
  107. idx_contract |= value
  108. else:
  109. remaining.append(value)
  110. idx_remain |= value
  111. new_result = idx_remain & idx_contract
  112. idx_removed = (idx_contract - new_result)
  113. remaining.append(new_result)
  114. return (new_result, remaining, idx_removed, idx_contract)
def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Computes all possible pair contractions, sieves the results based
    on ``memory_limit`` and returns the lowest cost path. This algorithm
    scales factorial with respect to the elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    # Each entry is (cost so far, list of pair contractions taken, terms left).
    full_results = [(0, [], input_sets)]
    # Each pairwise contraction shrinks the term list by one, so n-1 steps.
    for iteration in range(len(input_sets) - 1):
        iter_results = []

        # Compute all unique pairs
        for curr in full_results:
            cost, positions, remaining = curr
            # After `iteration` contractions only n - iteration terms remain.
            for con in itertools.combinations(range(len(input_sets) - iteration), 2):

                # Find the contraction
                cont = _find_contraction(con, remaining, output_set)
                new_result, new_input_sets, idx_removed, idx_contract = cont

                # Sieve the results based on memory_limit
                new_size = _compute_size_by_dict(new_result, idx_dict)
                if new_size > memory_limit:
                    continue

                # Build (total_cost, positions, indices_remaining)
                total_cost = cost + _flop_count(idx_contract, idx_removed, len(con), idx_dict)
                new_pos = positions + [con]
                iter_results.append((total_cost, new_pos, new_input_sets))

        # Update combinatorial list, if we did not find anything return best
        # path + remaining contractions
        if iter_results:
            full_results = iter_results
        else:
            # Memory sieve rejected everything: take the cheapest partial path
            # and finish with one big contraction over all remaining terms.
            path = min(full_results, key=lambda x: x[0])[1]
            path += [tuple(range(len(input_sets) - iteration))]
            return path

    # If we have not found anything return single einsum contraction
    if len(full_results) == 0:
        return [tuple(range(len(input_sets)))]

    path = min(full_results, key=lambda x: x[0])[1]
    return path
  173. def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
  174. """Compute the cost (removed size + flops) and resultant indices for
  175. performing the contraction specified by ``positions``.
  176. Parameters
  177. ----------
  178. positions : tuple of int
  179. The locations of the proposed tensors to contract.
  180. input_sets : list of sets
  181. The indices found on each tensors.
  182. output_set : set
  183. The output indices of the expression.
  184. idx_dict : dict
  185. Mapping of each index to its size.
  186. memory_limit : int
  187. The total allowed size for an intermediary tensor.
  188. path_cost : int
  189. The contraction cost so far.
  190. naive_cost : int
  191. The cost of the unoptimized expression.
  192. Returns
  193. -------
  194. cost : (int, int)
  195. A tuple containing the size of any indices removed, and the flop cost.
  196. positions : tuple of int
  197. The locations of the proposed tensors to contract.
  198. new_input_sets : list of sets
  199. The resulting new list of indices if this proposed contraction is performed.
  200. """
  201. # Find the contraction
  202. contract = _find_contraction(positions, input_sets, output_set)
  203. idx_result, new_input_sets, idx_removed, idx_contract = contract
  204. # Sieve the results based on memory_limit
  205. new_size = _compute_size_by_dict(idx_result, idx_dict)
  206. if new_size > memory_limit:
  207. return None
  208. # Build sort tuple
  209. old_sizes = (_compute_size_by_dict(input_sets[p], idx_dict) for p in positions)
  210. removed_size = sum(old_sizes) - new_size
  211. # NB: removed_size used to be just the size of any removed indices i.e.:
  212. # helpers.compute_size_by_dict(idx_removed, idx_dict)
  213. cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
  214. sort = (-removed_size, cost)
  215. # Sieve based on total cost as well
  216. if (path_cost + cost) > naive_cost:
  217. return None
  218. # Add contraction to possible choices
  219. return [sort, positions, new_input_sets]
def _update_other_results(results, best):
    """Update the positions and provisional input_sets of ``results`` based on
    performing the contraction result ``best``. Remove any involving the tensors
    contracted.

    Parameters
    ----------
    results : list
        List of contraction results produced by ``_parse_possible_contraction``.
    best : list
        The best contraction of ``results`` i.e. the one that will be performed.

    Returns
    -------
    mod_results : list
        The list of modified results, updated with outcome of ``best`` contraction.
    """
    best_con = best[1]
    bx, by = best_con
    mod_results = []
    for cost, (x, y), con_sets in results:

        # Ignore results involving tensors just contracted
        if x in best_con or y in best_con:
            continue

        # Update the input_sets: drop the two tensors ``best`` consumed, then
        # splice in its new intermediate just before this result's own one.
        # The int(...) terms compensate for positions already deleted by this
        # result's provisional contraction of (x, y).
        del con_sets[by - int(by > x) - int(by > y)]
        del con_sets[bx - int(bx > x) - int(bx > y)]
        con_sets.insert(-1, best[2][-1])

        # Update the position indices: shift (x, y) down past the two slots
        # removed by ``best``.
        mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
        mod_results.append((cost, mod_con, con_sets))

    return mod_results
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Finds the path by contracting the best pair until the input list is
    exhausted. The best pair is found by minimizing the tuple
    ``(-prod(indices_removed), cost)``.  What this amounts to is prioritizing
    matrix multiplication or inner product operations, then Hadamard like
    operations, and finally outer operations. Outer products are limited by
    ``memory_limit``. This algorithm scales cubically with respect to the
    number of elements in the list ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """

    # Handle trivial cases that leaked through
    if len(input_sets) == 1:
        return [(0,)]
    elif len(input_sets) == 2:
        return [(0, 1)]

    # Build up a naive cost (one contraction over everything) as a ceiling.
    contract = _find_contraction(range(len(input_sets)), input_sets, output_set)
    idx_result, new_input_sets, idx_removed, idx_contract = contract
    naive_cost = _flop_count(idx_contract, idx_removed, len(input_sets), idx_dict)

    # Initially iterate over all pairs
    comb_iter = itertools.combinations(range(len(input_sets)), 2)
    known_contractions = []

    path_cost = 0
    path = []

    for iteration in range(len(input_sets) - 1):

        # Iterate over all pairs on first step, only previously found pairs on subsequent steps
        for positions in comb_iter:

            # Always initially ignore outer products
            if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
                continue

            result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost,
                                                 naive_cost)
            if result is not None:
                known_contractions.append(result)

        # If we do not have a inner contraction, rescan pairs including outer products
        if len(known_contractions) == 0:

            # Then check the outer products
            for positions in itertools.combinations(range(len(input_sets)), 2):
                result = _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit,
                                                     path_cost, naive_cost)
                if result is not None:
                    known_contractions.append(result)

            # If we still did not find any remaining contractions, default back to einsum like behavior
            if len(known_contractions) == 0:
                path.append(tuple(range(len(input_sets))))
                break

        # Sort based on first index (most memory freed, then fewest flops)
        best = min(known_contractions, key=lambda x: x[0])

        # Now propagate as many unused contractions as possible to next iteration
        known_contractions = _update_other_results(known_contractions, best)

        # Next iteration only compute contractions with the new tensor
        # All other contractions have been accounted for
        input_sets = best[2]
        new_tensor_pos = len(input_sets) - 1
        comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))

        # Update path and total cost
        path.append(best[1])
        path_cost += best[0][1]

    return path
  330. def _can_dot(inputs, result, idx_removed):
  331. """
  332. Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.
  333. Parameters
  334. ----------
  335. inputs : list of str
  336. Specifies the subscripts for summation.
  337. result : str
  338. Resulting summation.
  339. idx_removed : set
  340. Indices that are removed in the summation
  341. Returns
  342. -------
  343. type : bool
  344. Returns true if BLAS should and can be used, else False
  345. Notes
  346. -----
  347. If the operations is BLAS level 1 or 2 and is not already aligned
  348. we default back to einsum as the memory movement to copy is more
  349. costly than the operation itself.
  350. Examples
  351. --------
  352. # Standard GEMM operation
  353. >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
  354. True
  355. # Can use the standard BLAS, but requires odd data movement
  356. >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
  357. False
  358. # DDOT where the memory is not aligned
  359. >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
  360. False
  361. """
  362. # All `dot` calls remove indices
  363. if len(idx_removed) == 0:
  364. return False
  365. # BLAS can only handle two operands
  366. if len(inputs) != 2:
  367. return False
  368. input_left, input_right = inputs
  369. for c in set(input_left + input_right):
  370. # can't deal with repeated indices on same input or more than 2 total
  371. nl, nr = input_left.count(c), input_right.count(c)
  372. if (nl > 1) or (nr > 1) or (nl + nr > 2):
  373. return False
  374. # can't do implicit summation or dimension collapse e.g.
  375. # "ab,bc->c" (implicitly sum over 'a')
  376. # "ab,ca->ca" (take diagonal of 'a')
  377. if nl + nr - 1 == int(c in result):
  378. return False
  379. # Build a few temporaries
  380. set_left = set(input_left)
  381. set_right = set(input_right)
  382. keep_left = set_left - idx_removed
  383. keep_right = set_right - idx_removed
  384. rs = len(idx_removed)
  385. # At this point we are a DOT, GEMV, or GEMM operation
  386. # Handle inner products
  387. # DDOT with aligned data
  388. if input_left == input_right:
  389. return True
  390. # DDOT without aligned data (better to use einsum)
  391. if set_left == set_right:
  392. return False
  393. # Handle the 4 possible (aligned) GEMV or GEMM cases
  394. # GEMM or GEMV no transpose
  395. if input_left[-rs:] == input_right[:rs]:
  396. return True
  397. # GEMM or GEMV transpose both
  398. if input_left[:rs] == input_right[-rs:]:
  399. return True
  400. # GEMM or GEMV transpose right
  401. if input_left[-rs:] == input_right[-rs:]:
  402. return True
  403. # GEMM or GEMV transpose left
  404. if input_left[:rs] == input_right[:rs]:
  405. return True
  406. # Einsum is faster than GEMV if we have to copy data
  407. if not keep_left or not keep_right:
  408. return False
  409. # We are a matrix-matrix product, but we need to copy data
  410. return True
def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """

    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], basestring):
        # String form: ("ij,jk->ik", a, b)
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: (a, [0, 1], b, [1, 2], [0, 2]); rebuild the
        # equivalent subscript string from the integer/Ellipsis lists.
        tmp_operands = list(operands)
        operand_list = []
        subscript_list = []
        for p in range(len(operands) // 2):
            operand_list.append(tmp_operands.pop(0))
            subscript_list.append(tmp_operands.pop(0))

        # A trailing odd element is the explicit output subscript list.
        output_list = tmp_operands[-1] if len(tmp_operands) else None
        operands = [asanyarray(v) for v in operand_list]
        subscripts = ""
        last = len(subscript_list) - 1
        for num, sub in enumerate(subscript_list):
            for s in sub:
                if s is Ellipsis:
                    subscripts += "..."
                elif isinstance(s, int):
                    subscripts += einsum_symbols[s]
                else:
                    raise TypeError("For this input type lists must contain "
                                    "either int or Ellipsis")
            if num != last:
                subscripts += ","

        if output_list is not None:
            subscripts += "->"
            for s in output_list:
                if s is Ellipsis:
                    subscripts += "..."
                elif isinstance(s, int):
                    subscripts += einsum_symbols[s]
                else:
                    raise TypeError("For this input type lists must contain "
                                    "either int or Ellipsis")
    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses: replace each "..." with concrete unused symbols.
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # Broadcast dims share the tail of the unused symbols so
                    # they line up across operands.
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: output is every index that
            # appears exactly once, broadcast dims first.
            output_subscript = ""
            tmp_subscripts = subscripts.replace(",", "")
            for s in sorted(set(tmp_subscripts)):
                if s not in (einsum_symbols):
                    raise ValueError("Character %s is not a valid symbol." % s)
                if tmp_subscripts.count(s) == 1:
                    output_subscript += s
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        # Build output subscripts: indices appearing exactly once, sorted.
        tmp_subscripts = subscripts.replace(",", "")
        output_subscript = ""
        for s in sorted(set(tmp_subscripts)):
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)
            if tmp_subscripts.count(s) == 1:
                output_subscript += s

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)
def _einsum_path_dispatcher(*operands, **kwargs):
    """Dispatcher for ``__array_function__`` overrides of ``einsum_path``."""
    # NOTE: technically, we should only dispatch on array-like arguments, not
    # subscripts (given as strings). But separating operands into
    # arrays/subscripts is a little tricky/slow (given einsum's two supported
    # signatures), so as a practical shortcut we dispatch on everything.
    # Strings will be ignored for dispatching since they don't define
    # __array_function__.
    return operands
  563. @array_function_dispatch(_einsum_path_dispatcher, module='numpy')
  564. def einsum_path(*operands, **kwargs):
  565. """
  566. einsum_path(subscripts, *operands, optimize='greedy')
  567. Evaluates the lowest cost contraction order for an einsum expression by
  568. considering the creation of intermediate arrays.
  569. Parameters
  570. ----------
  571. subscripts : str
  572. Specifies the subscripts for summation.
  573. *operands : list of array_like
  574. These are the arrays for the operation.
  575. optimize : {bool, list, tuple, 'greedy', 'optimal'}
  576. Choose the type of path. If a tuple is provided, the second argument is
  577. assumed to be the maximum intermediate size created. If only a single
  578. argument is provided the largest input or output array size is used
  579. as a maximum intermediate size.
  580. * if a list is given that starts with ``einsum_path``, uses this as the
  581. contraction path
  582. * if False no optimization is taken
  583. * if True defaults to the 'greedy' algorithm
  584. * 'optimal' An algorithm that combinatorially explores all possible
  585. ways of contracting the listed tensors and choosest the least costly
  586. path. Scales exponentially with the number of terms in the
  587. contraction.
  588. * 'greedy' An algorithm that chooses the best pair contraction
  589. at each step. Effectively, this algorithm searches the largest inner,
  590. Hadamard, and then outer products at each step. Scales cubically with
  591. the number of terms in the contraction. Equivalent to the 'optimal'
  592. path for most contractions.
  593. Default is 'greedy'.
  594. Returns
  595. -------
  596. path : list of tuples
  597. A list representation of the einsum path.
  598. string_repr : str
  599. A printable representation of the einsum path.
  600. Notes
  601. -----
  602. The resulting path indicates which terms of the input contraction should be
  603. contracted first, the result of this contraction is then appended to the
  604. end of the contraction list. This list can then be iterated over until all
  605. intermediate contractions are complete.
  606. See Also
  607. --------
  608. einsum, linalg.multi_dot
  609. Examples
  610. --------
  611. We can begin with a chain dot example. In this case, it is optimal to
  612. contract the ``b`` and ``c`` tensors first as represented by the first
  613. element of the path ``(1, 2)``. The resulting tensor is added to the end
  614. of the contraction and the remaining contraction ``(0, 1)`` is then
  615. completed.
  616. >>> np.random.seed(123)
  617. >>> a = np.random.rand(2, 2)
  618. >>> b = np.random.rand(2, 5)
  619. >>> c = np.random.rand(5, 2)
  620. >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
  621. >>> print(path_info[0])
  622. ['einsum_path', (1, 2), (0, 1)]
  623. >>> print(path_info[1])
  624. Complete contraction: ij,jk,kl->il # may vary
  625. Naive scaling: 4
  626. Optimized scaling: 3
  627. Naive FLOP count: 1.600e+02
  628. Optimized FLOP count: 5.600e+01
  629. Theoretical speedup: 2.857
  630. Largest intermediate: 4.000e+00 elements
  631. -------------------------------------------------------------------------
  632. scaling current remaining
  633. -------------------------------------------------------------------------
  634. 3 kl,jk->jl ij,jl->il
  635. 3 jl,ij->il il->il
  636. A more complex index transformation example.
  637. >>> I = np.random.rand(10, 10, 10, 10)
  638. >>> C = np.random.rand(10, 10)
  639. >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
  640. ... optimize='greedy')
  641. >>> print(path_info[0])
  642. ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
  643. >>> print(path_info[1])
  644. Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
  645. Naive scaling: 8
  646. Optimized scaling: 5
  647. Naive FLOP count: 8.000e+08
  648. Optimized FLOP count: 8.000e+05
  649. Theoretical speedup: 1000.000
  650. Largest intermediate: 1.000e+04 elements
  651. --------------------------------------------------------------------------
  652. scaling current remaining
  653. --------------------------------------------------------------------------
  654. 5 abcd,ea->bcde fb,gc,hd,bcde->efgh
  655. 5 bcde,fb->cdef gc,hd,cdef->efgh
  656. 5 cdef,gc->defg hd,defg->efgh
  657. 5 defg,hd->efgh efgh->efgh
  658. """
  659. # Make sure all keywords are valid
  660. valid_contract_kwargs = ['optimize', 'einsum_call']
  661. unknown_kwargs = [k for (k, v) in kwargs.items() if k
  662. not in valid_contract_kwargs]
  663. if len(unknown_kwargs):
  664. raise TypeError("Did not understand the following kwargs:"
  665. " %s" % unknown_kwargs)
  666. # Figure out what the path really is
  667. path_type = kwargs.pop('optimize', True)
  668. if path_type is True:
  669. path_type = 'greedy'
  670. if path_type is None:
  671. path_type = False
  672. memory_limit = None
  673. # No optimization or a named path algorithm
  674. if (path_type is False) or isinstance(path_type, basestring):
  675. pass
  676. # Given an explicit path
  677. elif len(path_type) and (path_type[0] == 'einsum_path'):
  678. pass
  679. # Path tuple with memory limit
  680. elif ((len(path_type) == 2) and isinstance(path_type[0], basestring) and
  681. isinstance(path_type[1], (int, float))):
  682. memory_limit = int(path_type[1])
  683. path_type = path_type[0]
  684. else:
  685. raise TypeError("Did not understand the path: %s" % str(path_type))
  686. # Hidden option, only einsum should call this
  687. einsum_call_arg = kwargs.pop("einsum_call", False)
  688. # Python side parsing
  689. input_subscripts, output_subscript, operands = _parse_einsum_input(operands)
  690. # Build a few useful list and sets
  691. input_list = input_subscripts.split(',')
  692. input_sets = [set(x) for x in input_list]
  693. output_set = set(output_subscript)
  694. indices = set(input_subscripts.replace(',', ''))
  695. # Get length of each unique dimension and ensure all dimensions are correct
  696. dimension_dict = {}
  697. broadcast_indices = [[] for x in range(len(input_list))]
  698. for tnum, term in enumerate(input_list):
  699. sh = operands[tnum].shape
  700. if len(sh) != len(term):
  701. raise ValueError("Einstein sum subscript %s does not contain the "
  702. "correct number of indices for operand %d."
  703. % (input_subscripts[tnum], tnum))
  704. for cnum, char in enumerate(term):
  705. dim = sh[cnum]
  706. # Build out broadcast indices
  707. if dim == 1:
  708. broadcast_indices[tnum].append(char)
  709. if char in dimension_dict.keys():
  710. # For broadcasting cases we always want the largest dim size
  711. if dimension_dict[char] == 1:
  712. dimension_dict[char] = dim
  713. elif dim not in (1, dimension_dict[char]):
  714. raise ValueError("Size of label '%s' for operand %d (%d) "
  715. "does not match previous terms (%d)."
  716. % (char, tnum, dimension_dict[char], dim))
  717. else:
  718. dimension_dict[char] = dim
  719. # Convert broadcast inds to sets
  720. broadcast_indices = [set(x) for x in broadcast_indices]
  721. # Compute size of each input array plus the output array
  722. size_list = [_compute_size_by_dict(term, dimension_dict)
  723. for term in input_list + [output_subscript]]
  724. max_size = max(size_list)
  725. if memory_limit is None:
  726. memory_arg = max_size
  727. else:
  728. memory_arg = memory_limit
  729. # Compute naive cost
  730. # This isn't quite right, need to look into exactly how einsum does this
  731. inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
  732. naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)
  733. # Compute the path
  734. if (path_type is False) or (len(input_list) in [1, 2]) or (indices == output_set):
  735. # Nothing to be optimized, leave it to einsum
  736. path = [tuple(range(len(input_list)))]
  737. elif path_type == "greedy":
  738. path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
  739. elif path_type == "optimal":
  740. path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
  741. elif path_type[0] == 'einsum_path':
  742. path = path_type[1:]
  743. else:
  744. raise KeyError("Path name %s not found", path_type)
  745. cost_list, scale_list, size_list, contraction_list = [], [], [], []
  746. # Build contraction tuple (positions, gemm, einsum_str, remaining)
  747. for cnum, contract_inds in enumerate(path):
  748. # Make sure we remove inds from right to left
  749. contract_inds = tuple(sorted(list(contract_inds), reverse=True))
  750. contract = _find_contraction(contract_inds, input_sets, output_set)
  751. out_inds, input_sets, idx_removed, idx_contract = contract
  752. cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
  753. cost_list.append(cost)
  754. scale_list.append(len(idx_contract))
  755. size_list.append(_compute_size_by_dict(out_inds, dimension_dict))
  756. bcast = set()
  757. tmp_inputs = []
  758. for x in contract_inds:
  759. tmp_inputs.append(input_list.pop(x))
  760. bcast |= broadcast_indices.pop(x)
  761. new_bcast_inds = bcast - idx_removed
  762. # If we're broadcasting, nix blas
  763. if not len(idx_removed & bcast):
  764. do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
  765. else:
  766. do_blas = False
  767. # Last contraction
  768. if (cnum - len(path)) == -1:
  769. idx_result = output_subscript
  770. else:
  771. sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
  772. idx_result = "".join([x[1] for x in sorted(sort_result)])
  773. input_list.append(idx_result)
  774. broadcast_indices.append(new_bcast_inds)
  775. einsum_str = ",".join(tmp_inputs) + "->" + idx_result
  776. contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
  777. contraction_list.append(contraction)
  778. opt_cost = sum(cost_list) + 1
  779. if einsum_call_arg:
  780. return (operands, contraction_list)
  781. # Return the path along with a nice string representation
  782. overall_contraction = input_subscripts + "->" + output_subscript
  783. header = ("scaling", "current", "remaining")
  784. speedup = naive_cost / opt_cost
  785. max_i = max(size_list)
  786. path_print = " Complete contraction: %s\n" % overall_contraction
  787. path_print += " Naive scaling: %d\n" % len(indices)
  788. path_print += " Optimized scaling: %d\n" % max(scale_list)
  789. path_print += " Naive FLOP count: %.3e\n" % naive_cost
  790. path_print += " Optimized FLOP count: %.3e\n" % opt_cost
  791. path_print += " Theoretical speedup: %3.3f\n" % speedup
  792. path_print += " Largest intermediate: %.3e elements\n" % max_i
  793. path_print += "-" * 74 + "\n"
  794. path_print += "%6s %24s %40s\n" % header
  795. path_print += "-" * 74
  796. for n, contraction in enumerate(contraction_list):
  797. inds, idx_rm, einsum_str, remaining, blas = contraction
  798. remaining_str = ",".join(remaining) + "->" + output_subscript
  799. path_run = (scale_list[n], einsum_str, remaining_str)
  800. path_print += "\n%4d %24s %40s" % path_run
  801. path = ['einsum_path'] + path
  802. return (path, path_print)
  803. def _einsum_dispatcher(*operands, **kwargs):
  804. # Arguably we dispatch on more arguments that we really should; see note in
  805. # _einsum_path_dispatcher for why.
  806. for op in operands:
  807. yield op
  808. yield kwargs.get('out')
# Rewrite einsum to handle different cases
@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode `einsum` computes these values.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout as the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur. Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

          * 'no' means the data types should not be cast at all.
          * 'equiv' means only byte-order changes are allowed.
          * 'safe' means only casts which can preserve values are allowed.
          * 'same_kind' means only safe casts or casts within a kind,
            like float64 to float32, are allowed.
          * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot

    Notes
    -----
    .. versionadded:: 1.6.0

    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul` :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner` :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication, :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order, :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)`` produces a
    view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)``
    describes traditional matrix multiplication and is equivalent to
    :py:func:`np.matmul(a,b) <numpy.matmul>`. Repeated subscript labels in one
    operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent
    to :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically. This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels. This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <numpy.sum>`,
    and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <numpy.diag>`.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis. Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts
    and operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    .. versionadded:: 1.10.0

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    .. versionadded:: 1.12.0

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands this
    can greatly increase the computational efficiency at the cost of a larger
    memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive search.
    For iterative calculations it may be advisable to calculate the optimal path
    once and reuse that path by supplying it as an argument. An example is given
    below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing the
    'optimal' path and repeatedly applying it, using an
    `einsum_path` insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Grab non-einsum kwargs; do not optimize by default.
    optimize_arg = kwargs.pop('optimize', False)

    # If no optimization, run pure einsum (the C implementation handles all
    # remaining kwarg validation itself).
    if optimize_arg is False:
        return c_einsum(*operands, **kwargs)

    # Keyword arguments that are forwarded to each c_einsum call below.
    valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']
    einsum_kwargs = {k: v for (k, v) in kwargs.items() if
                     k in valid_einsum_kwargs}

    # Make sure all keywords are valid
    valid_contract_kwargs = ['optimize'] + valid_einsum_kwargs
    unknown_kwargs = [k for (k, v) in kwargs.items() if
                      k not in valid_contract_kwargs]

    if len(unknown_kwargs):
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Special handling if out is specified: 'out' is popped here so that only
    # the final contraction (not the intermediates) writes into it.
    specified_out = False
    out_array = einsum_kwargs.pop('out', None)
    if out_array is not None:
        specified_out = True

    # Build the contraction list and operand
    # (einsum_call=True is the hidden flag making einsum_path return the
    # parsed operands and per-step contraction tuples instead of a path).
    operands, contraction_list = einsum_path(*operands, optimize=optimize_arg,
                                             einsum_call=True)

    handle_out = False

    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, remaining, blas = contraction
        # inds is sorted in reverse, so popping never shifts a later index.
        tmp_operands = [operands.pop(x) for x in inds]

        # Do we need to deal with the output? (only on the last contraction)
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot if still possible
        if blas:
            # Checks have already been handled (the path builder only sets
            # blas for two-operand contractions that _can_dot accepted).
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # Index string of the raw tensordot result: left labels followed
            # by right labels, with the contracted labels removed.
            tensor_result = input_left + input_right
            for s in idx_rm:
                tensor_result = tensor_result.replace(s, "")

            # Find indices to contract over
            # NOTE(review): str.find is used without a -1 check — presumably
            # every label in idx_rm appears in both inputs here; verify
            # against the _can_dot conditions.
            left_pos, right_pos = [], []
            for s in sorted(idx_rm):
                left_pos.append(input_left.find(s))
                right_pos.append(input_right.find(s))

            # Contract!
            new_view = tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)))

            # Build a new view if needed (a final transpose/relabel via
            # c_einsum, which is also where 'out' gets filled on the last step)
            if (tensor_result != results_index) or handle_out:
                if handle_out:
                    einsum_kwargs["out"] = out_array
                new_view = c_einsum(tensor_result + '->' + results_index, new_view, **einsum_kwargs)

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                einsum_kwargs["out"] = out_array

            # Do the contraction
            new_view = c_einsum(einsum_str, *tmp_operands, **einsum_kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out_array
    else:
        return operands[0]