1
0
Эх сурвалжийг харах

refactor: 其他SQL语句使用硬编码

SongZihuan 2 жил өмнө
parent
commit
105388b377
7 өөрчлөгдсөн 70 нэмэгдсэн , 100 устгасан
  1. 6 7
      sql/archive.py
  2. 8 12
      sql/base.py
  3. 11 9
      sql/blog.py
  4. 3 4
      sql/comment.py
  5. 3 4
      sql/msg.py
  6. 8 31
      sql/mysql.py
  7. 31 33
      sql/user.py

+ 6 - 7
sql/archive.py

@@ -6,9 +6,8 @@ def create_archive(name: str, describe: str):
     """ 创建新归档 """
     name = name.replace("'", "''")
     describe = describe.replace("'", "''")
-    cur = db.insert(table="archive",
-                    columns=["Name", "DescribeText"],
-                    values=f"'{name}', '{describe}'")
+    cur = db.insert("INSERT INTO archive(Name, DescribeText) "
+                    "VALUES (%s, %s)", name, describe)
     if cur is None or cur.rowcount == 0:
         return None
     return cur.lastrowid
@@ -35,10 +34,10 @@ def get_blog_archive(blog_id: int):
 
 
 def delete_archive(archive_id: int):
-    cur = db.delete(table="blog_archive", where=f"ArchiveID={archive_id}")
+    cur = db.delete("DELETE FROM blog_archive WHERE ArchiveID=%s", archive_id)
     if cur is None:
         return False
-    cur = db.delete(table="archive", where=f"ID={archive_id}")
+    cur = db.delete("DELETE FROM archive WHERE ID=%s", archive_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
@@ -50,14 +49,14 @@ def add_blog_to_archive(blog_id: int, archive_id: int):
         return False
     if cur.rowcount > 0:
         return True
-    cur = db.insert(table="blog_archive", columns=["BlogID", "ArchiveID"], values=f"{blog_id}, {archive_id}")
+    cur = db.insert("INSERT INTO blog_archive(BlogID, ArchiveID) VALUES (%s, %s)", blog_id, archive_id)
     if cur is None or cur.rowcount != 1:
         return False
     return True
 
 
 def sub_blog_from_archive(blog_id: int, archive_id: int):
-    cur = db.delete(table="blog_archive", where=f"BlogID={blog_id} AND ArchiveID={archive_id}")
+    cur = db.delete("DELETE FROM blog_archive WHERE BlogID=%s AND ArchiveID=%s", blog_id, archive_id)
     if cur is None:
         return False
     return True

+ 8 - 12
sql/base.py

@@ -64,41 +64,37 @@ class Database(metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def search(self, sql: str, *args):
+    def search(self, sql: str, *args, not_commit: bool = False):
         """
         执行 查询 SQL语句
         :parm sql: SQL语句
+        :return:
         """
         ...
 
     @abc.abstractmethod
-    def insert(self, table: str, columns: list, values: Union[str, List[str]]):
+    def insert(self, sql: str, *args, not_commit: bool = False):
         """
         执行 插入 SQL语句, 并提交
-        :param table: 表
-        :param columns: 列名称
-        :param values: 数据
+        :parm sql: SQL语句
         :return:
         """
         ...
 
     @abc.abstractmethod
-    def delete(self, table: str, where: Union[str, List[str]] = None):
+    def delete(self, sql: str, *args, not_commit: bool = False):
         """
         执行 删除 SQL语句, 并提交
-        :param table: 表
-        :param where: 条件
+        :parm sql: SQL语句
         :return:
         """
         ...
 
     @abc.abstractmethod
-    def update(self, table: str, kw: "Dict[str:str]", where: Union[str, List[str]] = None):
+    def update(self, sql: str, *args, not_commit: bool = False):
         """
         执行 更新 SQL语句, 并提交
-        :param table: 表
-        :param kw: 键值对
-        :param where: 条件
+        :parm sql: SQL语句
         :return:
         """
         ...

+ 11 - 9
sql/blog.py

@@ -10,8 +10,8 @@ def create_blog(auth_id: int, title: str, subtitle: str, content: str,
     title = title.replace("'", "''")
     subtitle = subtitle.replace("'", "''")
     content = content.replace("'", "''")
-    cur = db.insert(table="blog", columns=["Auth", "Title", "SubTitle", "Content"],
-                    values=f"{auth_id}, '{title}', '{subtitle}', '{content}'")
+    cur = db.insert("INSERT INTO blog(Auth, Title, SubTitle, Content) "
+                    "VALUES (%s, %s, %s, %s)", auth_id, title, subtitle, content)
     if cur is None or cur.rowcount == 0:
         return False
     blog_id = cur.lastrowid
@@ -24,9 +24,9 @@ def create_blog(auth_id: int, title: str, subtitle: str, content: str,
 def update_blog(blog_id: int, content: str) -> bool:
     """ 更新博客文章 """
     content = content.replace("'", "''")
-    cur = db.update(table="blog",
-                    kw={"UpdateTime": "CURRENT_TIMESTAMP()", "Content": f"'{content}'"},
-                    where=f"ID={blog_id}")
+    cur = db.update("Update blog "
+                    "SET UpdateTime=CURRENT_TIMESTAMP(), Content=%s "
+                    "WHERE ID=%s", content, blog_id)
     if cur is None or cur.rowcount != 1:
         return False
     return True
@@ -43,20 +43,22 @@ def read_blog(blog_id: int) -> list:
 
 
 def delete_blog(blog_id: int):
-    cur = db.delete(table="blog_archive", where=f"BlogID={blog_id}")
+    cur = db.delete("DELETE FROM blog_archive WHERE BlogID=%s", blog_id)
     if cur is None:
         return False
-    cur = db.delete(table="comment", where=f"BlogID={blog_id}")
+    cur = db.delete("DELETE FROM comment WHERE BlogID=%s", blog_id)
     if cur is None:
         return False
-    cur = db.delete(table="blog", where=f"ID={blog_id}")
+    cur = db.delete("DELETE FROM blog WHERE ID=%s", blog_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
 def set_blog_top(blog_id: int, top: bool = True):
-    cur = db.update(table="blog", kw={"Top": "1" if top else "0"}, where=f"ID={blog_id}")
+    cur = db.update("UPDATE blog "
+                    "SET Top=%s "
+                    "WHERE ID=%s", 1 if top else 0, blog_id)
     if cur is None or cur.rowcount != 1:
         return False
     return True

+ 3 - 4
sql/comment.py

@@ -15,9 +15,8 @@ def read_comment_list(blog_id: int):
 def create_comment(blog_id: int, user_id: int, content: str):
     """ 新建 comment """
     content = content.replace("'", "''")
-    cur = db.insert(table="comment",
-                    columns=["BlogID", "Auth", "Content"],
-                    values=f"{blog_id}, {user_id}, '{content}'")
+    cur = db.insert("INSERT INTO comment(BlogID, Auth, Content) "
+                    "VALUES (%s, %s, %s)", blog_id, user_id, content)
     if cur is None or cur.rowcount == 0:
         return False
     return True
@@ -33,7 +32,7 @@ def read_comment(comment_id: int):
 
 def delete_comment(comment_id):
     """ 删除评论 """
-    cur = db.delete(table="comment", where=f"ID={comment_id}")
+    cur = db.delete("DELETE FROM comment WHERE ID=%s", comment_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True

+ 3 - 4
sql/msg.py

@@ -34,9 +34,8 @@ def read_msg_list(limit: Optional[int] = None, offset: Optional[int] = None, sho
 
 def create_msg(auth: int, content: str, secret: bool = False):
     content = content.replace("'", "''")
-    cur = db.insert(table="message",
-                    columns=["Auth", "Content", "Secret"],
-                    values=f"{auth}, '{content}', {1 if secret else 0}")
+    cur = db.insert("INSERT INTO message(Auth, Content, Secret) "
+                    "VALUES (%s, %s, %s)", auth, content, 1 if secret else 0)
     if cur is None or cur.rowcount != 1:
         return None
     return cur.lastrowid
@@ -52,7 +51,7 @@ def read_msg(msg_id: int):
 
 
 def delete_msg(msg_id: int):
-    cur = db.delete(table="message", where=f"ID={msg_id}")
+    cur = db.delete("DELETE FROM message WHERE ID=%s", msg_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True

+ 8 - 31
sql/mysql.py

@@ -57,37 +57,14 @@ class MysqlDB(Database):
     def search(self, sql: str, *args) -> Union[None, pymysql.cursors.Cursor]:
         return self.__search(sql, args)
 
-    def insert(self, table: str, columns: list, values: Union[str, List[str]],
-               not_commit: bool = False) -> Union[None, pymysql.cursors.Cursor]:
-        columns: str = ", ".join(columns)
-        if type(values) is str:
-            values: str = f"({values})"
-        else:
-            values: str = ", ".join(f"{v}" for v in values)
-        return self.__done(f"INSERT INTO {table}({columns}) VALUES {values};", not_commit=not_commit)
-
-    def delete(self, table: str, where: Union[str, List[str]] = None,
-               not_commit: bool = False) -> Union[None, pymysql.cursors.Cursor]:
-        if type(where) is list and len(where) > 0:
-            where: str = " AND ".join(f"({w})" for w in where)
-        elif type(where) is not str or len(where) == 0:  # 必须指定条件
-            return None
-
-        return self.__done(f"DELETE FROM {table} WHERE {where};", not_commit=not_commit)
+    def insert(self, sql: str, *args, not_commit: bool = False) -> Union[None, pymysql.cursors.Cursor]:
+        return self.__done(sql, args, not_commit=not_commit)
 
-    def update(self, table: str, kw: "Dict[str:str]", where: Union[str, List[str]] = None,
-               not_commit: bool = False) -> Union[None, pymysql.cursors.Cursor]:
-        if len(kw) == 0:
-            return None
+    def delete(self, sql: str, *args, not_commit: bool = False) -> Union[None, pymysql.cursors.Cursor]:
+        return self.__done(sql, args, not_commit=not_commit)
 
-        if type(where) is list and len(where) > 0:
-            where: str = " AND ".join(f"({w})" for w in where)
-        elif type(where) is not str or len(where) == 0:  # 必须指定条件
-            return None
-
-        kw_list = [f"{key} = {kw[key]}" for key in kw]
-        kw_str = ", ".join(kw_list)
-        return self.__done(f"UPDATE {table} SET {kw_str} WHERE {where};", not_commit=not_commit)
+    def update(self, sql: str, *args, not_commit: bool = False) -> Union[None, pymysql.cursors.Cursor]:
+        return self.__done(sql, args, not_commit=not_commit)
 
     def commit(self):
         self._commit()
@@ -107,13 +84,13 @@ class MysqlDB(Database):
             self._lock.release()  # 释放锁
         return self._cursor
 
-    def __done(self, sql, not_commit: bool = False) -> Union[None, pymysql.cursors.Cursor]:
+    def __done(self, sql, args, not_commit: bool = False) -> Union[None, pymysql.cursors.Cursor]:
         try:
             self._lock.acquire()
             if not self.is_connect():
                 self.logger.error(f"MySQL({self._name}@{self._host}) connect error")
                 return
-            self._cursor.execute(sql)
+            self._cursor.execute(query=sql, args=args)
         except pymysql.MySQLError:
             self._db.rollback()
             self.logger.error(f"MySQL({self._name}@{self._host}) SQL {sql} error {inspect.stack()[2][2]} "

+ 31 - 33
sql/user.py

@@ -28,9 +28,11 @@ def create_user(email: str, passwd: str):
     passwd = object.user.User.get_passwd_hash(passwd)
     if cur is None or cur.rowcount == 0 or cur.fetchone()[0] == 0:
         # 创建为管理员用户
-        cur = db.insert(table='user', columns=['Email', 'PasswdHash', 'Role'], values=f"'{email}', '{passwd}', 1")
+        cur = db.insert("INSERT INTO user(Email, PasswdHash, Role) "
+                        "VALUES (%s, %s, %s)", email, passwd, 1)
     else:
-        cur = db.insert(table='user', columns=['Email', 'PasswdHash'], values=f"'{email}', '{passwd}'")
+        cur = db.insert("INSERT INTO user(Email, PasswdHash) "
+                        "VALUES (%s, %s)", email, passwd)
     if cur is None or cur.rowcount != 1:
         return None
     return cur.lastrowid
@@ -38,56 +40,66 @@ def create_user(email: str, passwd: str):
 
 def delete_user(user_id: int):
     """ 删除用户 """
-    cur = db.delete(table="message", where=f"Auth={user_id}")
+    cur = db.delete("DELETE FROM message WHERE Auth=%s", user_id)
     if cur is None:
         return False
-    cur = db.delete(table="comment", where=f"Auth={user_id}")
+    cur = db.delete("DELETE FROM comment WHERE Auth=%s", user_id)
     if cur is None:
         return False
-    cur = db.delete(table="blog", where=f"Auth={user_id}")
+    cur = db.delete("DELETE FROM blog WHERE Auth=%s", user_id)
     if cur is None:
         return False
-    cur = db.delete(table="user", where=f"ID={user_id}")
+    cur = db.delete("DELETE FROM user WHERE ID=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
+def __authority_to_sql(authority):
+    """ authority 转换为 Update语句, 不检查合法性 """
+    sql = []
+    args = []
+    for i in authority:
+        sql.append(f"{i}=%s")
+        args.append(authority[i])
+    return ",".join(sql), args
+
+
 def create_role(name: str, authority: List[str]):
     name = name.replace("'", "''")
-    cur = db.insert(table="role", columns=["RoleName"], values=f"'{name}'", not_commit=True)
+    cur = db.insert("INSERT INTO role(RoleName) VALUES (%s)", name)
     if cur is None or cur.rowcount == 0:
         return False
 
-    kw = {}
-    for i in role_authority:
-        kw[i] = '0'
-    for i in authority:
-        if i in role_authority:
-            kw[i] = '1'
-
-    cur = db.update(table='role', kw=kw, where=f"RoleName='{name}'")
+    sql, args = __authority_to_sql({i: (1 if i in authority else 0) for i in role_authority})
+    cur = db.update(f"UPDATE role "
+                    f"SET {sql} "
+                    f"WHERE RoleName=%s", *args, name)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
 def delete_role(role_id: int):
-    cur = db.delete(table="role", where=f"RoleID={role_id}")
+    cur = db.delete("DELETE FROM role WHERE RoleID=%s", role_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
 def set_user_role(role_id: int, user_id: str):
-    cur = db.update(table="user", kw={"Role": f"{role_id}"}, where=f"ID={user_id}")
+    cur = db.update("UPDATE user "
+                    "SET Role=%s "
+                    "WHERE ID=%s", role_id, user_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
 def change_passwd_hash(user_id: int, passwd_hash: str):
-    cur = db.update(table='user', kw={'PasswdHash': f"'{passwd_hash}'"}, where=f'ID={user_id}')
+    cur = db.update("UPDATE user "
+                    "SET PasswdHash=%s "
+                    "WHERE ID=%s", passwd_hash, user_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
@@ -110,21 +122,7 @@ def get_role_name(role: int):
 
 
 def __check_operate(operate):
-    return operate in ["WriteBlog",
-                       "WriteComment",
-                       "WriteMsg",
-                       "CreateUser",
-                       "ReadBlog",
-                       "ReadComment",
-                       "ReadMsg",
-                       "ReadSecretMsg",
-                       "ReadUserInfo",
-                       "DeleteBlog",
-                       "DeleteComment",
-                       "DeleteMsg",
-                       "DeleteUser",
-                       "ConfigureSystem",
-                       "ReadSystem"]
+    return operate in role_authority
 
 
 def check_role(role: int, operate: str):