test_polyutils.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. """Tests for polyutils module.
  2. """
  3. from __future__ import division, absolute_import, print_function
  4. import numpy as np
  5. import numpy.polynomial.polyutils as pu
  6. from numpy.testing import (
  7. assert_almost_equal, assert_raises, assert_equal, assert_,
  8. )
  9. class TestMisc(object):
  10. def test_trimseq(self):
  11. for i in range(5):
  12. tgt = [1]
  13. res = pu.trimseq([1] + [0]*5)
  14. assert_equal(res, tgt)
  15. def test_as_series(self):
  16. # check exceptions
  17. assert_raises(ValueError, pu.as_series, [[]])
  18. assert_raises(ValueError, pu.as_series, [[[1, 2]]])
  19. assert_raises(ValueError, pu.as_series, [[1], ['a']])
  20. # check common types
  21. types = ['i', 'd', 'O']
  22. for i in range(len(types)):
  23. for j in range(i):
  24. ci = np.ones(1, types[i])
  25. cj = np.ones(1, types[j])
  26. [resi, resj] = pu.as_series([ci, cj])
  27. assert_(resi.dtype.char == resj.dtype.char)
  28. assert_(resj.dtype.char == types[i])
  29. def test_trimcoef(self):
  30. coef = [2, -1, 1, 0]
  31. # Test exceptions
  32. assert_raises(ValueError, pu.trimcoef, coef, -1)
  33. # Test results
  34. assert_equal(pu.trimcoef(coef), coef[:-1])
  35. assert_equal(pu.trimcoef(coef, 1), coef[:-3])
  36. assert_equal(pu.trimcoef(coef, 2), [0])
  37. class TestDomain(object):
  38. def test_getdomain(self):
  39. # test for real values
  40. x = [1, 10, 3, -1]
  41. tgt = [-1, 10]
  42. res = pu.getdomain(x)
  43. assert_almost_equal(res, tgt)
  44. # test for complex values
  45. x = [1 + 1j, 1 - 1j, 0, 2]
  46. tgt = [-1j, 2 + 1j]
  47. res = pu.getdomain(x)
  48. assert_almost_equal(res, tgt)
  49. def test_mapdomain(self):
  50. # test for real values
  51. dom1 = [0, 4]
  52. dom2 = [1, 3]
  53. tgt = dom2
  54. res = pu.mapdomain(dom1, dom1, dom2)
  55. assert_almost_equal(res, tgt)
  56. # test for complex values
  57. dom1 = [0 - 1j, 2 + 1j]
  58. dom2 = [-2, 2]
  59. tgt = dom2
  60. x = dom1
  61. res = pu.mapdomain(x, dom1, dom2)
  62. assert_almost_equal(res, tgt)
  63. # test for multidimensional arrays
  64. dom1 = [0, 4]
  65. dom2 = [1, 3]
  66. tgt = np.array([dom2, dom2])
  67. x = np.array([dom1, dom1])
  68. res = pu.mapdomain(x, dom1, dom2)
  69. assert_almost_equal(res, tgt)
  70. # test that subtypes are preserved.
  71. class MyNDArray(np.ndarray):
  72. pass
  73. dom1 = [0, 4]
  74. dom2 = [1, 3]
  75. x = np.array([dom1, dom1]).view(MyNDArray)
  76. res = pu.mapdomain(x, dom1, dom2)
  77. assert_(isinstance(res, MyNDArray))
  78. def test_mapparms(self):
  79. # test for real values
  80. dom1 = [0, 4]
  81. dom2 = [1, 3]
  82. tgt = [1, .5]
  83. res = pu. mapparms(dom1, dom2)
  84. assert_almost_equal(res, tgt)
  85. # test for complex values
  86. dom1 = [0 - 1j, 2 + 1j]
  87. dom2 = [-2, 2]
  88. tgt = [-1 + 1j, 1 - 1j]
  89. res = pu.mapparms(dom1, dom2)
  90. assert_almost_equal(res, tgt)