1
0

user.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. from sql import db
  2. from sql.base import DBBit
  3. import object.user
  4. from typing import List
  5. role_authority = ["WriteBlog", "WriteComment", "WriteMsg", "CreateUser",
  6. "ReadBlog", "ReadComment", "ReadMsg", "ReadSecretMsg", "ReadUserInfo",
  7. "DeleteBlog", "DeleteComment", "DeleteMsg", "DeleteUser",
  8. "ConfigureSystem", "ReadSystem"]
  9. def read_user(email: str):
  10. """ 读取用户 """
  11. cur = db.search("SELECT PasswdHash, Role, ID FROM user WHERE Email=%s", email)
  12. if cur is None or cur.rowcount != 1:
  13. return ["", -1, -1]
  14. return cur.fetchone()
  15. def create_user(email: str, passwd: str):
  16. """ 创建用户 """
  17. email = email.replace("'", "''")
  18. if len(email) == 0:
  19. return None
  20. cur = db.search("SELECT COUNT(*) FROM user")
  21. passwd = object.user.User.get_passwd_hash(passwd)
  22. if cur is None or cur.rowcount == 0 or cur.fetchone()[0] == 0:
  23. # 创建为管理员用户
  24. cur = db.insert("INSERT INTO user(Email, PasswdHash, Role) "
  25. "VALUES (%s, %s, %s)", email, passwd, 1)
  26. else:
  27. cur = db.insert("INSERT INTO user(Email, PasswdHash) "
  28. "VALUES (%s, %s)", email, passwd)
  29. if cur is None or cur.rowcount != 1:
  30. return None
  31. return cur.lastrowid
  32. def delete_user(user_id: int):
  33. """ 删除用户 """
  34. cur = db.delete("DELETE FROM message WHERE Auth=%s", user_id)
  35. if cur is None:
  36. return False
  37. cur = db.delete("DELETE FROM comment WHERE Auth=%s", user_id)
  38. if cur is None:
  39. return False
  40. cur = db.delete("DELETE FROM blog WHERE Auth=%s", user_id)
  41. if cur is None:
  42. return False
  43. cur = db.delete("DELETE FROM user WHERE ID=%s", user_id)
  44. if cur is None or cur.rowcount == 0:
  45. return False
  46. return True
  47. def __authority_to_sql(authority):
  48. """ authority 转换为 Update语句, 不检查合法性 """
  49. sql = []
  50. args = []
  51. for i in authority:
  52. sql.append(f"{i}=%s")
  53. args.append(authority[i])
  54. return ",".join(sql), args
  55. def create_role(name: str, authority: List[str]):
  56. name = name.replace("'", "''")
  57. cur = db.insert("INSERT INTO role(RoleName) VALUES (%s)", name)
  58. if cur is None or cur.rowcount == 0:
  59. return False
  60. sql, args = __authority_to_sql({i: (1 if i in authority else 0) for i in role_authority})
  61. cur = db.update(f"UPDATE role "
  62. f"SET {sql} "
  63. f"WHERE RoleName=%s", *args, name)
  64. if cur is None or cur.rowcount == 0:
  65. return False
  66. return True
  67. def delete_role(role_id: int):
  68. cur = db.delete("DELETE FROM role WHERE RoleID=%s", role_id)
  69. if cur is None or cur.rowcount == 0:
  70. return False
  71. return True
  72. def set_user_role(role_id: int, user_id: str):
  73. cur = db.update("UPDATE user "
  74. "SET Role=%s "
  75. "WHERE ID=%s", role_id, user_id)
  76. if cur is None or cur.rowcount == 0:
  77. return False
  78. return True
  79. def change_passwd_hash(user_id: int, passwd_hash: str):
  80. cur = db.update("UPDATE user "
  81. "SET PasswdHash=%s "
  82. "WHERE ID=%s", passwd_hash, user_id)
  83. if cur is None or cur.rowcount == 0:
  84. return False
  85. return True
  86. def get_user_email(user_id):
  87. """ 获取用户邮箱 """
  88. cur = db.search("SELECT Email FROM user WHERE ID=%s", user_id)
  89. if cur is None or cur.rowcount == 0:
  90. return None
  91. return cur.fetchone()[0]
  92. def get_role_name(role: int):
  93. """ 获取用户角色名称 """
  94. cur = db.search("SELECT RoleName FROM role WHERE RoleID=%s", role)
  95. if cur is None or cur.rowcount == 0:
  96. return None
  97. return cur.fetchone()[0]
  98. def __check_operate(operate):
  99. return operate in role_authority
  100. def check_role(role: int, operate: str):
  101. """ 检查角色权限(通过角色ID) """
  102. if not __check_operate(operate): # 检查, 防止SQL注入
  103. return False
  104. cur = db.search(f"SELECT {operate} FROM role WHERE RoleID=%s", role)
  105. if cur is None or cur.rowcount == 0:
  106. return False
  107. return cur.fetchone()[0] == DBBit.BIT_1
  108. def check_role_by_name(role: str, operate: str):
  109. """ 检查角色权限(通过角色名) """
  110. if not __check_operate(operate): # 检查, 防止SQL注入
  111. return False
  112. role = role.replace("'", "''")
  113. cur = db.search(f"SELECT {operate} FROM role WHERE RoleName=%s", role)
  114. if cur is None or cur.rowcount == 0:
  115. return False
  116. return cur.fetchone()[0] == DBBit.BIT_1
  117. def get_role_id_by_name(role: str):
  118. """ 检查角色权限(通过角色名) """
  119. role = role.replace("'", "''")
  120. cur = db.search("SELECT RoleID FROM role WHERE RoleName=%s", role)
  121. if cur is None or cur.rowcount == 0:
  122. return None
  123. return cur.fetchone()[0]
  124. def get_role_list():
  125. """ 获取归档列表 """
  126. cur = db.search("SELECT RoleID, RoleName FROM role")
  127. if cur is None or cur.rowcount == 0:
  128. return []
  129. return cur.fetchall()