Sfoglia il codice sorgente

refactor: SQL的SELECT语句使用硬编码

SongZihuan 2 anni fa
parent
commit
3c70630e4b
7 ha cambiato i file con 125 aggiunte e 125 eliminazioni
  1. 18 12
      sql/archive.py
  2. 2 12
      sql/base.py
  3. 35 22
      sql/blog.py
  4. 6 9
      sql/comment.py
  5. 27 13
      sql/msg.py
  6. 7 48
      sql/mysql.py
  7. 30 9
      sql/user.py

+ 18 - 12
sql/archive.py

@@ -16,9 +16,9 @@ def create_archive(name: str, describe: str):
 
 def read_archive(archive_id: int):
     """ 获取归档 ID """
-    cur = db.search(columns=["Name", "DescribeText"],
-                    table="archive",
-                    where=f"ID={archive_id}")
+    cur = db.search("SELECT Name, DescribeText "
+                    "FROM archive "
+                    "WHERE ID=%s", archive_id)
     if cur is None or cur.rowcount == 0:
         return ["", ""]
     return cur.fetchone()
@@ -26,10 +26,9 @@ def read_archive(archive_id: int):
 
 def get_blog_archive(blog_id: int):
     """ 获取文章的归档 """
-    cur = db.search(columns=["ArchiveID"],
-                    table="blog_archive_with_name",
-                    where=f"BlogID={blog_id}",
-                    order_by=[("ArchiveName", "ASC")])
+    cur = db.search("SELECT ArchiveID FROM blog_archive_with_name "
+                    "WHERE BlogID=%s "
+                    "ORDER BY ArchiveName", blog_id)
     if cur is None or cur.rowcount == 0:
         return []
     return [i[0] for i in cur.fetchall()]
@@ -46,7 +45,7 @@ def delete_archive(archive_id: int):
 
 
 def add_blog_to_archive(blog_id: int, archive_id: int):
-    cur = db.search(columns=["BlogID"], table="blog_archive", where=f"BlogID={blog_id} AND ArchiveID={archive_id}")
+    cur = db.search("SELECT BlogID FROM blog_archive WHERE BlogID=%s AND ArchiveID=%s", blog_id, archive_id)
     if cur is None:
         return False
     if cur.rowcount > 0:
@@ -66,10 +65,17 @@ def sub_blog_from_archive(blog_id: int, archive_id: int):
 
 def get_archive_list(limit: Optional[int] = None, offset: Optional[int] = None):
     """ 获取归档列表 """
-    cur = db.search(columns=["ID", "Name", "DescribeText", "Count"], table="archive_with_count",
-                    limit=limit,
-                    offset=offset,
-                    order_by=[("Count", "DESC"), ("Name", "ASC")])
+    if limit is not None and offset is not None:
+        cur = db.search("SELECT ID, Name, DescribeText, Count "
+                        "FROM archive_with_count "
+                        "ORDER BY Count DESC , Name "
+                        "LIMIT %s "
+                        "OFFSET %s ", limit, offset)
+    else:
+        cur = db.search("SELECT ID, Name, DescribeText, Count "
+                        "FROM archive_with_count "
+                        "ORDER BY Count DESC , Name")
+
     if cur is None or cur.rowcount == 0:
         return []
     return cur.fetchall()

+ 2 - 12
sql/base.py

@@ -64,20 +64,10 @@ class Database(metaclass=abc.ABCMeta):
         ...
 
     @abc.abstractmethod
-    def search(self, columns: List[str], table: str,
-               where: Union[str, List[str]] = None,
-               limit: Optional[int] = None,
-               offset: Optional[int] = None,
-               order_by: Optional[List[Tuple[str, str]]] = None):
+    def search(self, sql: str, *args):
         """
         执行 查询 SQL语句
-        :param columns: 列名称
-        :param table: 表
-        :param where: 条件
-        :param limit: 限制行数
-        :param offset: 偏移
-        :param order_by: 排序方式
-        :return:
+        :parm sql: SQL语句
         """
         ...
 

+ 35 - 22
sql/blog.py

@@ -34,9 +34,9 @@ def update_blog(blog_id: int, content: str) -> bool:
 
 def read_blog(blog_id: int) -> list:
     """ 读取blog内容 """
-    cur = db.search(columns=["Auth", "Title", "SubTitle", "Content", "UpdateTime", "CreateTime", "Top"],
-                    table="blog",
-                    where=f"ID={blog_id}")
+    cur = db.search("SELECT Auth, Title, SubTitle, Content, UpdateTime, CreateTime, Top "
+                    "FROM blog "
+                    "WHERE ID=%s", blog_id)
     if cur is None or cur.rowcount == 0:
         return [-1, "", "", "", 0, -1, False]
     return cur.fetchone()
@@ -64,10 +64,15 @@ def set_blog_top(blog_id: int, top: bool = True):
 
 def get_blog_list(limit: Optional[int] = None, offset: Optional[int] = None) -> list:
     """ 获得 blog 列表 """
-    cur = db.search(columns=["ID", "Title", "SubTitle", "UpdateTime", "CreateTime", "Top"], table="blog_with_top",
-                    order_by=[("Top", "DESC"), ("CreateTime", "DESC"), ("Title", "ASC"), ("SubTitle", "ASC")],
-                    limit=limit,
-                    offset=offset)
+    if limit is not None and offset is not None:
+        cur = db.search("SELECT ID, Title, SubTitle, UpdateTime, CreateTime, Top "
+                        "FROM blog_with_top "  # TODO: 去除blog_with_top
+                        "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle "
+                        "LIMIT %s OFFSET %s", limit, offset)
+    else:
+        cur = db.search("SELECT ID, Title, SubTitle, UpdateTime, CreateTime, Top "
+                        "FROM blog_with_top "  # TODO: 去除blog_with_top
+                        "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle")
     if cur is None or cur.rowcount == 0:
         return []
     return cur.fetchall()
@@ -75,10 +80,15 @@ def get_blog_list(limit: Optional[int] = None, offset: Optional[int] = None) ->
 
 def get_blog_list_not_top(limit: Optional[int] = None, offset: Optional[int] = None) -> list:
     """ 获得blog列表 忽略置顶 """
-    cur = db.search(columns=["ID", "Title", "SubTitle", "UpdateTime", "CreateTime"], table="blog",
-                    order_by=[("CreateTime", "DESC"), ("Title", "ASC"), ("SubTitle", "ASC")],
-                    limit=limit,
-                    offset=offset)
+    if limit is not None and offset is not None:
+        cur = db.search("SELECT ID, Title, SubTitle, UpdateTime, CreateTime "
+                        "FROM blog "
+                        "ORDER BY CreateTime DESC, Title, SubTitle "
+                        "LIMIT %s OFFSET %s", limit, offset)
+    else:
+        cur = db.search("SELECT ID, Title, SubTitle, UpdateTime, CreateTime "
+                        "FROM blog "
+                        "ORDER BY CreateTime DESC, Title, SubTitle")
     if cur is None or cur.rowcount == 0:
         return []
     return cur.fetchall()
@@ -86,7 +96,7 @@ def get_blog_list_not_top(limit: Optional[int] = None, offset: Optional[int] = N
 
 def get_blog_count() -> int:
     """ 统计 blog 个数 """
-    cur = db.search(columns=["count(ID)"], table="blog")
+    cur = db.search("SELECT COUNT(*) FROM blog")
     if cur is None or cur.rowcount == 0:
         return 0
     return cur.fetchone()[0]
@@ -94,12 +104,17 @@ def get_blog_count() -> int:
 
 def get_archive_blog_list(archive_id, limit: Optional[int] = None, offset: Optional[int] = None) -> list:
     """ 获得指定归档的 blog 列表 """
-    cur = db.search(columns=["BlogID", "Title", "SubTitle", "UpdateTime", "CreateTime", "Top"],
-                    table="blog_with_archive",
-                    order_by=[("Top", "DESC"), ("CreateTime", "DESC"), ("Title", "ASC"), ("SubTitle", "ASC")],
-                    where=f"ArchiveID={archive_id}",
-                    limit=limit,
-                    offset=offset)
+    if limit is not None and offset is not None:
+        cur = db.search("SELECT BlogID, Title, SubTitle, UpdateTime, CreateTime, Top "
+                        "FROM blog_with_archive "
+                        "WHERE ArchiveID=%s "
+                        "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle "
+                        "LIMIT %s OFFSET %s", archive_id, limit, offset)
+    else:
+        cur = db.search("SELECT BlogID, Title, SubTitle, UpdateTime, CreateTime, Top "
+                        "FROM blog_with_archive "
+                        "WHERE ArchiveID=%s "
+                        "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle")
     if cur is None or cur.rowcount == 0:
         return []
     return cur.fetchall()
@@ -107,8 +122,7 @@ def get_archive_blog_list(archive_id, limit: Optional[int] = None, offset: Optio
 
 def get_archive_blog_count(archive_id) -> int:
     """ 统计指定归档的 blog 个数 """
-    cur = db.search(columns=["count(BlogID)"], table="blog_with_archive",
-                    where=f"ArchiveID={archive_id}")
+    cur = db.search("SELECT COUNT(*) FROM blog_with_archive WHERE ArchiveID=%s", archive_id)
     if cur is None or cur.rowcount == 0:
         return 0
     return cur.fetchone()[0]
@@ -116,8 +130,7 @@ def get_archive_blog_count(archive_id) -> int:
 
 def get_user_user_count(user_id: int) -> int:
     """ 获得指定用户的 blog 个数 """
-    cur = db.search(columns=["count(ID)"], table="blog",
-                    where=f"Auth={user_id}")
+    cur = db.search("SELECT COUNT(*) FROM blog WHERE Auth=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return 0
     return cur.fetchone()[0]

+ 6 - 9
sql/comment.py

@@ -3,10 +3,10 @@ from sql import db
 
 def read_comment_list(blog_id: int):
     """ 读取文章的 comment """
-    cur = db.search(columns=["CommentID"],
-                    table="comment_user",
-                    where=f"BlogID={blog_id}",
-                    order_by=[("UpdateTime", "DESC")])
+    cur = db.search("SELECT CommentID "
+                    "FROM comment_user "
+                    "WHERE BlogID=%s "
+                    "ORDER BY UpdateTime DESC", blog_id)
     if cur is None or cur.rowcount == 0:
         return []
     return [i[0] for i in cur.fetchall()]
@@ -25,9 +25,7 @@ def create_comment(blog_id: int, user_id: int, content: str):
 
 def read_comment(comment_id: int):
     """ 读取 comment """
-    cur = db.search(columns=["BlogID", "Email", "Content", "UpdateTime"],
-                    table="comment_user",
-                    where=f"CommentID={comment_id}")
+    cur = db.search("SELECT BlogID, Email, Content, UpdateTime FROM comment_user WHERE CommentID=%s", comment_id)
     if cur is None or cur.rowcount == 0:
         return [-1, "", "", 0]
     return cur.fetchone()
@@ -43,8 +41,7 @@ def delete_comment(comment_id):
 
 def get_user_comment_count(user_id: int):
     """ 读取指定用户的 comment 个数 """
-    cur = db.search(columns=["count(ID)"], table="comment",
-                    where=f"Auth={user_id}")
+    cur = db.search("SELECT COUNT(*) FROM comment WHERE Auth=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return 0
     return cur.fetchone()[0]

+ 27 - 13
sql/msg.py

@@ -4,15 +4,29 @@ from typing import Optional
 
 def read_msg_list(limit: Optional[int] = None, offset: Optional[int] = None, show_secret: bool = False):
     if show_secret:
-        where = None
+        if limit is not None and offset is not None:
+            cur = db.search("SELECT MsgID "
+                            "FROM message_user "
+                            "ORDER BY UpdateTime DESC "
+                            "LIMIT %s "
+                            "OFFSET %s", limit, offset)
+        else:
+            cur = db.search("SELECT MsgID "
+                            "FROM message_user "
+                            "ORDER BY UpdateTime DESC")
     else:
-        where = "Secret=0"
-
-    cur = db.search(columns=["MsgID"], table="message_user",
-                    limit=limit,
-                    where=where,
-                    offset=offset,
-                    order_by=[("UpdateTime", "DESC")])
+        if limit is not None and offset is not None:
+            cur = db.search("SELECT MsgID "
+                            "FROM message_user "
+                            "WHERE Secret=0 "
+                            "ORDER BY UpdateTime DESC "
+                            "LIMIT %s "
+                            "OFFSET %s", limit, offset)
+        else:
+            cur = db.search("SELECT MsgID "
+                            "FROM message_user "
+                            "WHERE Secret=0 "
+                            "ORDER BY UpdateTime DESC")
     if cur is None or cur.rowcount == 0:
         return []
     return [i[0] for i in cur.fetchall()]
@@ -29,8 +43,9 @@ def create_msg(auth: int, content: str, secret: bool = False):
 
 
 def read_msg(msg_id: int):
-    cur = db.search(columns=["Email", "Content", "UpdateTime", "Secret"], table="message_user",
-                    where=f"MsgID={msg_id}")
+    cur = db.search("SELECT Email, Content, UpdateTime, Secret "
+                    "FROM message_user "
+                    "WHERE MsgID=%s", msg_id)
     if cur is None or cur.rowcount == 0:
         return ["", "", 0, False]
     return cur.fetchone()
@@ -44,15 +59,14 @@ def delete_msg(msg_id: int):
 
 
 def get_msg_count():
-    cur = db.search(columns=["count(ID)"], table="message")
+    cur = db.search("SELECT COUNT(*) FROM message")
     if cur is None or cur.rowcount == 0:
         return 0
     return cur.fetchone()[0]
 
 
 def get_user_msg_count(user_id: int):
-    cur = db.search(columns=["count(ID)"], table="message",
-                    where=f"Auth={user_id}")
+    cur = db.search("SELECT COUNT(*) FROM message WHERE Auth=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return 0
     return cur.fetchone()[0]

+ 7 - 48
sql/mysql.py

@@ -54,49 +54,8 @@ class MysqlDB(Database):
             raise DBCloseException
         return self._cursor
 
-    def search(self, columns: List[str], table: str,
-               where: Union[str, List[str]] = None,
-               limit: Optional[int] = None,
-               offset: Optional[int] = None,
-               order_by: Optional[List[Tuple[str, str]]] = None,
-               group_by: Optional[List[str]] = None,
-               for_update: bool = False) -> Union[None, pymysql.cursors.Cursor]:
-        if type(where) is list and len(where) > 0:
-            where: str = " WHERE " + " AND ".join(f"({w})" for w in where)
-        elif type(where) is str and len(where) > 0:
-            where = " WHERE " + where
-        else:
-            where: str = ""
-
-        if order_by is None:
-            order_by: str = ""
-        else:
-            by = [f" {i[0]} {i[1]} " for i in order_by]
-            order_by: str = " ORDER BY" + ", ".join(by)
-
-        if limit is None or limit == 0:
-            limit: str = ""
-        else:
-            limit = f" LIMIT {limit}"
-
-        if offset is None:
-            offset: str = ""
-        else:
-            offset = f" OFFSET {offset}"
-
-        if group_by is None:
-            group_by: str = ""
-        else:
-            group_by = "GROUP BY " + ", ".join(group_by)
-
-        columns: str = ", ".join(columns)
-        if for_update:
-            for_update = "FOR UPDATE"
-        else:
-            for_update = ""
-        return self.__search(f"SELECT {columns} "
-                             f"FROM {table} "
-                             f"{where} {group_by} {order_by} {limit} {offset} {for_update};")
+    def search(self, sql: str, *args) -> Union[None, pymysql.cursors.Cursor]:
+        return self.__search(sql, args)
 
     def insert(self, table: str, columns: list, values: Union[str, List[str]],
                not_commit: bool = False) -> Union[None, pymysql.cursors.Cursor]:
@@ -133,15 +92,15 @@ class MysqlDB(Database):
     def commit(self):
         self._commit()
 
-    def __search(self, sql) -> Union[None, pymysql.cursors.Cursor]:
+    def __search(self, sql, args) -> Union[None, pymysql.cursors.Cursor]:
         try:
             self._lock.acquire()  # 上锁
             if not self.is_connect():
-                self.logger.error(f"MySQL({self._name}@{self._host}) SQL {sql} connect error")
+                self.logger.error(f"MySQL({self._name}@{self._host}) connect error")
                 return
-            self._cursor.execute(sql)
+            self._cursor.execute(query=sql, args=args)
         except pymysql.MySQLError:
-            self.logger.error(f"MySQL({self._name}@{self._host}) SQL {sql} error {inspect.stack()[2][2]} "
+            self.logger.error(f"MySQL({self._name}@{self._host}) SQL {sql} with {args} error {inspect.stack()[2][2]} "
                               f"{inspect.stack()[2][1]} {inspect.stack()[2][3]}", exc_info=True, stack_info=True)
             return
         finally:
@@ -152,7 +111,7 @@ class MysqlDB(Database):
         try:
             self._lock.acquire()
             if not self.is_connect():
-                self.logger.error(f"MySQL({self._name}@{self._host}) SQL {sql} connect error")
+                self.logger.error(f"MySQL({self._name}@{self._host}) connect error")
                 return
             self._cursor.execute(sql)
         except pymysql.MySQLError:

+ 30 - 9
sql/user.py

@@ -4,7 +4,6 @@ import object.user
 
 from typing import List
 
-
 role_authority = ["WriteBlog", "WriteComment", "WriteMsg", "CreateUser",
                   "ReadBlog", "ReadComment", "ReadMsg", "ReadSecretMsg", "ReadUserInfo",
                   "DeleteBlog", "DeleteComment", "DeleteMsg", "DeleteUser",
@@ -13,7 +12,7 @@ role_authority = ["WriteBlog", "WriteComment", "WriteMsg", "CreateUser",
 
 def read_user(email: str):
     """ 读取用户 """
-    cur = db.search(columns=["PasswdHash", "Role", "ID"], table="user", where=f"Email='{email}'")
+    cur = db.search("SELECT PasswdHash, Role, ID FROM user WHERE Email=%s", email)
     if cur is None or cur.rowcount != 1:
         return ["", -1, -1]
     return cur.fetchone()
@@ -25,7 +24,7 @@ def create_user(email: str, passwd: str):
     if len(email) == 0:
         return None
 
-    cur = db.search(columns=["count(Email)"], table="user")  # 统计个数
+    cur = db.search("SELECT COUNT(*) FROM user")
     passwd = object.user.User.get_passwd_hash(passwd)
     if cur is None or cur.rowcount == 0 or cur.fetchone()[0] == 0:
         # 创建为管理员用户
@@ -96,7 +95,7 @@ def change_passwd_hash(user_id: int, passwd_hash: str):
 
 def get_user_email(user_id):
     """ 获取用户邮箱 """
-    cur = db.search(columns=["Email"], table="user", where=f"ID='{user_id}'")
+    cur = db.search("SELECT Email FROM user WHERE ID=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return None
     return cur.fetchone()[0]
@@ -104,15 +103,35 @@ def get_user_email(user_id):
 
 def get_role_name(role: int):
     """ 获取用户角色名称 """
-    cur = db.search(columns=["RoleName"], table="role", where=f"RoleID={role}")
+    cur = db.search("SELECT RoleName FROM role WHERE RoleID=%s", role)
     if cur is None or cur.rowcount == 0:
         return None
     return cur.fetchone()[0]
 
 
+def __check_operate(operate):
+    return operate in ["WriteBlog",
+                       "WriteComment",
+                       "WriteMsg",
+                       "CreateUser",
+                       "ReadBlog",
+                       "ReadComment",
+                       "ReadMsg",
+                       "ReadSecretMsg",
+                       "ReadUserInfo",
+                       "DeleteBlog",
+                       "DeleteComment",
+                       "DeleteMsg",
+                       "DeleteUser",
+                       "ConfigureSystem",
+                       "ReadSystem"]
+
+
 def check_role(role: int, operate: str):
     """ 检查角色权限(通过角色ID) """
-    cur = db.search(columns=[operate], table="role", where=f"RoleID={role}")
+    if not __check_operate(operate):  # 检查, 防止SQL注入
+        return False
+    cur = db.search(f"SELECT {operate} FROM role WHERE RoleID=%s", role)
     if cur is None or cur.rowcount == 0:
         return False
     return cur.fetchone()[0] == DBBit.BIT_1
@@ -120,8 +139,10 @@ def check_role(role: int, operate: str):
 
 def check_role_by_name(role: str, operate: str):
     """ 检查角色权限(通过角色名) """
+    if not __check_operate(operate):  # 检查, 防止SQL注入
+        return False
     role = role.replace("'", "''")
-    cur = db.search(columns=[operate], table="role", where=f"RoleName='{role}'")
+    cur = db.search(f"SELECT {operate} FROM role WHERE RoleName=%s", role)
     if cur is None or cur.rowcount == 0:
         return False
     return cur.fetchone()[0] == DBBit.BIT_1
@@ -130,7 +151,7 @@ def check_role_by_name(role: str, operate: str):
 def get_role_id_by_name(role: str):
     """ 检查角色权限(通过角色名) """
     role = role.replace("'", "''")
-    cur = db.search(columns=["RoleID"], table="role", where=f"RoleName='{role}'")
+    cur = db.search("SELECT RoleID FROM role WHERE RoleName=%s", role)
     if cur is None or cur.rowcount == 0:
         return None
     return cur.fetchone()[0]
@@ -138,7 +159,7 @@ def get_role_id_by_name(role: str):
 
 def get_role_list():
     """ 获取归档列表 """
-    cur = db.search(columns=["RoleID", "RoleName"], table="role")
+    cur = db.search("SELECT RoleID, RoleName FROM role")
     if cur is None or cur.rowcount == 0:
         return []
     return cur.fetchall()