# common.py (1.1 KB)

# (code-viewer line-number gutter: 1-31)
  1. def _check_mixed_float(df, dtype=None):
  2. # float16 are most likely to be upcasted to float32
  3. dtypes = dict(A="float32", B="float32", C="float16", D="float64")
  4. if isinstance(dtype, str):
  5. dtypes = {k: dtype for k, v in dtypes.items()}
  6. elif isinstance(dtype, dict):
  7. dtypes.update(dtype)
  8. if dtypes.get("A"):
  9. assert df.dtypes["A"] == dtypes["A"]
  10. if dtypes.get("B"):
  11. assert df.dtypes["B"] == dtypes["B"]
  12. if dtypes.get("C"):
  13. assert df.dtypes["C"] == dtypes["C"]
  14. if dtypes.get("D"):
  15. assert df.dtypes["D"] == dtypes["D"]
  16. def _check_mixed_int(df, dtype=None):
  17. dtypes = dict(A="int32", B="uint64", C="uint8", D="int64")
  18. if isinstance(dtype, str):
  19. dtypes = {k: dtype for k, v in dtypes.items()}
  20. elif isinstance(dtype, dict):
  21. dtypes.update(dtype)
  22. if dtypes.get("A"):
  23. assert df.dtypes["A"] == dtypes["A"]
  24. if dtypes.get("B"):
  25. assert df.dtypes["B"] == dtypes["B"]
  26. if dtypes.get("C"):
  27. assert df.dtypes["C"] == dtypes["C"]
  28. if dtypes.get("D"):
  29. assert df.dtypes["D"] == dtypes["D"]