# test_matrix_linalg.py (2.1 KB)
  1. """ Test functions for linalg module using the matrix class."""
  2. from __future__ import division, absolute_import, print_function
  3. import numpy as np
  4. from numpy.linalg.tests.test_linalg import (
  5. LinalgCase, apply_tag, TestQR as _TestQR, LinalgTestCase,
  6. _TestNorm2D, _TestNormDoubleBase, _TestNormSingleBase, _TestNormInt64Base,
  7. SolveCases, InvCases, EigvalsCases, EigCases, SVDCases, CondCases,
  8. PinvCases, DetCases, LstsqCases)
  9. CASES = []
  10. # square test cases
  11. CASES += apply_tag('square', [
  12. LinalgCase("0x0_matrix",
  13. np.empty((0, 0), dtype=np.double).view(np.matrix),
  14. np.empty((0, 1), dtype=np.double).view(np.matrix),
  15. tags={'size-0'}),
  16. LinalgCase("matrix_b_only",
  17. np.array([[1., 2.], [3., 4.]]),
  18. np.matrix([2., 1.]).T),
  19. LinalgCase("matrix_a_and_b",
  20. np.matrix([[1., 2.], [3., 4.]]),
  21. np.matrix([2., 1.]).T),
  22. ])
  23. # hermitian test-cases
  24. CASES += apply_tag('hermitian', [
  25. LinalgCase("hmatrix_a_and_b",
  26. np.matrix([[1., 2.], [2., 1.]]),
  27. None),
  28. ])
  29. # No need to make generalized or strided cases for matrices.
  30. class MatrixTestCase(LinalgTestCase):
  31. TEST_CASES = CASES
  32. class TestSolveMatrix(SolveCases, MatrixTestCase):
  33. pass
  34. class TestInvMatrix(InvCases, MatrixTestCase):
  35. pass
  36. class TestEigvalsMatrix(EigvalsCases, MatrixTestCase):
  37. pass
  38. class TestEigMatrix(EigCases, MatrixTestCase):
  39. pass
  40. class TestSVDMatrix(SVDCases, MatrixTestCase):
  41. pass
  42. class TestCondMatrix(CondCases, MatrixTestCase):
  43. pass
  44. class TestPinvMatrix(PinvCases, MatrixTestCase):
  45. pass
  46. class TestDetMatrix(DetCases, MatrixTestCase):
  47. pass
  48. class TestLstsqMatrix(LstsqCases, MatrixTestCase):
  49. pass
  50. class _TestNorm2DMatrix(_TestNorm2D):
  51. array = np.matrix
  52. class TestNormDoubleMatrix(_TestNorm2DMatrix, _TestNormDoubleBase):
  53. pass
  54. class TestNormSingleMatrix(_TestNorm2DMatrix, _TestNormSingleBase):
  55. pass
  56. class TestNormInt64Matrix(_TestNorm2DMatrix, _TestNormInt64Base):
  57. pass
  58. class TestQRMatrix(_TestQR):
  59. array = np.matrix