Browse Source

feat: SQL转义

SongZihuan 3 years ago
parent
commit
d34d5e4c5c
5 changed files with 12 additions and 0 deletions
  1. 2 0
      sql/archive.py
  2. 4 0
      sql/blog.py
  3. 1 0
      sql/comment.py
  4. 1 0
      sql/msg.py
  5. 4 0
      sql/user.py

+ 2 - 0
sql/archive.py

@@ -4,6 +4,8 @@ from typing import Optional
 
 def create_archive(name: str, describe: str):
     """ 创建新归档 """
+    name = name.replace("'", "''")
+    describe = describe.replace("'", "''")
     cur = db.insert(table="archive",
                     columns=["Name", "DescribeText"],
                     values=f"'{name}', '{describe}'")

+ 4 - 0
sql/blog.py

@@ -6,6 +6,9 @@ import object.archive
 def create_blog(auth_id: int, title: str, subtitle: str, context: str,
                 archive_list: List[object.archive.Archive]) -> bool:
     """ 写入新的blog """
+    title = title.replace("'", "''")
+    subtitle = subtitle.replace("'", "''")
+    context = context.replace("'", "''")
     cur = db.insert(table="blog", columns=["Auth", "Title", "SubTitle", "Context"],
                     values=f"{auth_id}, '{title}', '{subtitle}', '{context}'")
     if cur is None or cur.rowcount == 0:
@@ -21,6 +24,7 @@ def create_blog(auth_id: int, title: str, subtitle: str, context: str,
 
 def update_blog(blog_id: int, context: str) -> bool:
     """ 更新博客文章 """
+    context = context.replace("'", "''")
     cur = db.update(table="blog",
                     kw={"UpdateTime": "CURRENT_TIMESTAMP()", "Context": f"'{context}'"},
                     where=f"ID={blog_id}")

+ 1 - 0
sql/comment.py

@@ -13,6 +13,7 @@ def read_comment(blog_id: int):
 
 def create_comment(blog_id: int, user_id: int, context: str):
     """ 新建 comment """
+    context = context.replace("'", "''")
     cur = db.insert(table="comment",
                     columns=["BlogID", "Auth", "Context"],
                     values=f"{blog_id}, {user_id}, '{context}'")

+ 1 - 0
sql/msg.py

@@ -3,6 +3,7 @@ from typing import Optional
 
 
 def create_msg(auth: int, context: str, secret: bool = False):
+    context = context.replace("'", "''")
     cur = db.insert(table="message",
                     columns=["Auth", "Context", "Secret"],
                     values=f"{auth}, '{context}', {1 if secret else 0}")

+ 4 - 0
sql/user.py

@@ -22,6 +22,7 @@ def read_user(email: str):
 
 def create_user(email: str, passwd: str):
     """ 创建用户 """
+    email = email.replace("'", "''")
     cur = db.search(columns=["count(Email)"], table="user")  # 统计个数
     passwd = object.user.User.get_passwd_hash(passwd)
     if cur is None or cur.rowcount == 0 or cur.fetchone()[0] == 0:
@@ -48,6 +49,7 @@ def delete_user(user_id: int):
 
 
 def create_role(name: str, authority: List[str]):
+    name = name.replace("'", "''")
     cur = db.insert(table="role", columns=["RoleName"], values=f"'{name}'", not_commit=True)
     if cur is None or cur.rowcount == 0:
         return False
@@ -112,6 +114,7 @@ def check_role(role: int, operate: str):
 
 def check_role_by_name(role: str, operate: str):
     """ 检查角色权限(通过角色名) """
+    role = role.replace("'", "''")
     cur = db.search(columns=[operate], table="role", where=f"RoleName='{role}'")
     if cur is None or cur.rowcount == 0:
         return False
@@ -120,6 +123,7 @@ def check_role_by_name(role: str, operate: str):
 
 def get_role_id_by_name(role: str):
     """ 检查角色权限(通过角色名) """
+    role = role.replace("'", "''")
     cur = db.search(columns=["RoleID"], table="role", where=f"RoleName='{role}'")
     if cur is None or cur.rowcount == 0:
         return None