Ver Fonte

feat: 使用事务

SongZihuan há 2 anos atrás
pai
commit
add1ff7461
4 ficheiros alterados com 100 adições e 27 exclusões
  1. 12 2
      sql/archive.py
  2. 15 3
      sql/blog.py
  3. 44 16
      sql/mysql.py
  4. 29 6
      sql/user.py

+ 12 - 2
sql/archive.py

@@ -54,12 +54,22 @@ def get_blog_archive(blog_id: int, mysql: DB = db, not_cache=False):
 def delete_archive(archive_id: int, mysql: DB = db):
     delete_archive_from_cache(archive_id)
     delete_all_blog_archive_from_cache()
-    cur = mysql.delete("DELETE FROM blog_archive WHERE ArchiveID=%s", archive_id)
+    conn = mysql.get_connection()
+
+    cur = mysql.delete("DELETE FROM blog_archive WHERE ArchiveID=%s", archive_id, connection=conn)
     if cur is None:
+        conn.rollback()
+        conn.close()
         return False
-    cur = mysql.delete("DELETE FROM archive WHERE ID=%s", archive_id)
+
+    cur = mysql.delete("DELETE FROM archive WHERE ID=%s", archive_id, connection=conn)
     if cur is None or cur.rowcount == 0:
+        conn.rollback()
+        conn.close()
         return False
+
+    conn.commit()
+    conn.close()
     return True
 
 

+ 15 - 3
sql/blog.py

@@ -71,15 +71,27 @@ def delete_blog(blog_id: int, mysql: DB = db):
     delete_blog_from_cache(blog_id)
     delete_blog_archive_from_cache(blog_id)
 
-    cur = mysql.delete("DELETE FROM blog_archive WHERE BlogID=%s", blog_id)
+    conn = mysql.get_connection()
+    cur = mysql.delete("DELETE FROM blog_archive WHERE BlogID=%s", blog_id, connection=conn)
     if cur is None:
+        conn.rollback()
+        conn.close()
         return False
-    cur = mysql.delete("DELETE FROM comment WHERE BlogID=%s", blog_id)
+
+    cur = mysql.delete("DELETE FROM comment WHERE BlogID=%s", blog_id, connection=conn)
     if cur is None:
+        conn.rollback()
+        conn.close()
         return False
-    cur = mysql.delete("DELETE FROM blog WHERE ID=%s", blog_id)
+
+    cur = mysql.delete("DELETE FROM blog WHERE ID=%s", blog_id, connection=conn)
     if cur is None or cur.rowcount == 0:
+        conn.rollback()
+        conn.close()
         return False
+
+    conn.commit()
+    conn.close()
     return True
 
 

+ 44 - 16
sql/mysql.py

@@ -1,7 +1,6 @@
-import pymysql.cursors
 import pymysql
 from dbutils.pooled_db import PooledDB
-import threading
+from dbutils.steady_db import SteadyDBCursor
 from sql.base import Database, DBException, DBCloseException
 from typing import Optional, Union
 import inspect
@@ -13,7 +12,7 @@ class MysqlConnectException(DBCloseException):
 
 class MysqlDB(Database):
     class Result:
-        def __init__(self, cur: pymysql.cursors):
+        def __init__(self, cur: SteadyDBCursor):
             self.res: list = cur.fetchall()
             self.lastrowid: int = cur.lastrowid
             self.rowcount: int = cur.rowcount
@@ -27,6 +26,25 @@ class MysqlDB(Database):
         def __iter__(self):
             return self.res.__iter__()
 
+    class Connection:
+        def __init__(self, conn):
+            self.conn = conn
+            self.cur = conn.cursor()
+
+        def get_cursor(self):
+            return self.cur
+
+        def commit(self):
+            self.conn.commit()
+
+        def rollback(self):
+            self.conn.rollback()
+
+        def close(self):
+            self.cur.close()
+            self.conn.close()
+
+
     def __init__(self,
                  host: Optional[str],
                  name: Optional[str],
@@ -52,17 +70,20 @@ class MysqlDB(Database):
 
         self.logger.info(f"MySQL({self._name}@{self._host}) connect")
 
+    def get_connection(self):
+        return MysqlDB.Connection(self.pool.connection())
+
     def search(self, sql: str, *args) -> Union[None, Result]:
         return self.__search(sql, args)
 
-    def insert(self, sql: str, *args) -> Union[None, Result]:
-        return self.__done(sql, args)
+    def insert(self, sql: str, *args, connection: Connection = None) -> Union[None, Result]:
+        return self.__done(sql, args, connection)
 
-    def delete(self, sql: str, *args) -> Union[None, Result]:
-        return self.__done(sql, args)
+    def delete(self, sql: str, *args, connection: Connection = None) -> Union[None, Result]:
+        return self.__done(sql, args, connection)
 
-    def update(self, sql: str, *args) -> Union[None, Result]:
-        return self.__done(sql, args)
+    def update(self, sql: str, *args, connection: Connection = None) -> Union[None, Result]:
+        return self.__done(sql, args, connection)
 
     def __search(self, sql, args) -> Union[None, Result]:
         conn = self.pool.connection()
@@ -80,20 +101,27 @@ class MysqlDB(Database):
             cur.close()
             conn.close()
 
-    def __done(self, sql, args) -> Union[None, Result]:
-        conn = self.pool.connection()
-        cur = conn.cursor()
+    def __done(self, sql, args, connection: Connection = None) -> Union[None, Result]:
+        if connection:
+            cur = connection.get_cursor()
+            conn = None
+        else:
+            conn = self.pool.connection()
+            cur = conn.cursor()
 
         try:
             cur.execute(query=sql, args=args)
-            conn.commit()
+            if conn:
+                conn.commit()
         except pymysql.MySQLError:
-            conn.rollback()
+            if conn:
+                conn.rollback()
             self.logger.error(f"MySQL({self._name}@{self._host}) SQL {sql} error {inspect.stack()[2][2]} "
                               f"{inspect.stack()[2][1]} {inspect.stack()[2][3]}", exc_info=True, stack_info=True)
             return None
         else:
             return MysqlDB.Result(cur)
         finally:
-            cur.close()
-            conn.close()
+            if not connection:
+                cur.close()
+                conn.close()

+ 29 - 6
sql/user.py

@@ -55,18 +55,33 @@ def delete_user(user_id: int, mysql: DB = db):
     delete_user_from_cache(get_user_email(user_id))
     delete_user_email_from_cache(user_id)
 
-    cur = mysql.delete("DELETE FROM message WHERE Auth=%s", user_id)
+    conn = mysql.get_connection()
+    cur = mysql.delete("DELETE FROM message WHERE Auth=%s", user_id, connection=conn)
     if cur is None:
+        conn.rollback()
+        conn.close()
         return False
-    cur = mysql.delete("DELETE FROM comment WHERE Auth=%s", user_id)
+
+    cur = mysql.delete("DELETE FROM comment WHERE Auth=%s", user_id, connection=conn)
     if cur is None:
+        conn.rollback()
+        conn.close()
         return False
-    cur = mysql.delete("DELETE FROM blog WHERE Auth=%s", user_id)
+
+    cur = mysql.delete("DELETE FROM blog WHERE Auth=%s", user_id, connection=conn)
     if cur is None:
+        conn.rollback()
+        conn.close()
         return False
-    cur = mysql.delete("DELETE FROM user WHERE ID=%s", user_id)
+
+    cur = mysql.delete("DELETE FROM user WHERE ID=%s", user_id, connection=conn)
     if cur is None or cur.rowcount == 0:
+        conn.rollback()
+        conn.close()
         return False
+
+    conn.commit()
+    conn.close()
     return True
 
 
@@ -108,16 +123,24 @@ def __authority_to_sql(authority):
 
 
 def create_role(name: str, authority: List[str], mysql: DB = db):
-    cur = db.insert("INSERT INTO role(RoleName) VALUES (%s)", name)
+    conn = mysql.get_connection()
+    cur = mysql.insert("INSERT INTO role(RoleName) VALUES (%s)", name, connection=conn)
     if cur is None or cur.rowcount == 0:
+        conn.rollback()
+        conn.close()
         return False
 
     sql, args = __authority_to_sql({i: (1 if i in authority else 0) for i in role_authority})
     cur = mysql.update(f"UPDATE role "
                        f"SET {sql} "
-                       f"WHERE RoleName=%s", *args, name)
+                       f"WHERE RoleName=%s", *args, name, connection=conn)
     if cur is None or cur.rowcount == 0:
+        conn.rollback()
+        conn.close()
         return False
+
+    conn.commit()
+    conn.close()
     return True