瀏覽代碼

feat: 定时刷新缓存

SongZihuan 2 年之前
父節點
當前提交
2ee32d5a0d
共有 9 個文件被更改,包括 387 次插入227 次删除
  1. 1 0
      configure/__init__.py
  2. 16 4
      main.py
  3. 52 38
      sql/archive.py
  4. 84 69
      sql/blog.py
  5. 5 3
      sql/cache.py
  6. 62 0
      sql/cache_refresh.py
  7. 39 20
      sql/comment.py
  8. 58 44
      sql/msg.py
  9. 70 49
      sql/user.py

+ 1 - 0
configure/__init__.py

@@ -26,6 +26,7 @@ conf = {
     "CACHE_REDIS_PASSWD": "123456",
     "CACHE_REDIS_DATABASE": 0,
     "CACHE_EXPIRE": 604800,  # 默认七天过期
+    "CACHE_REFRESH_INTERVAL": 432000,  # 缓存刷新时间  默认五天刷新一次
     "VIEW_CACHE_EXPIRE": 60,  # 视图函数
     "LIST_CACHE_EXPIRE": 5,  # 列表 排行
     "REDIS_EXPIRE": "statistics",

+ 16 - 4
main.py

@@ -1,7 +1,8 @@
-from configure import configure
+from configure import configure, conf
 
 import os
 import logging
+import threading
 
 env_dict = os.environ
 hblog_conf = env_dict.get("hblog_conf")
@@ -12,15 +13,26 @@ else:
     logging.info(f"Configure file {hblog_conf}")
     configure(hblog_conf)
 
-from sql.cache import restart_clear_cache
-restart_clear_cache()  # 清理缓存
-
 from app import HBlogFlask
 from waitress import serve
 
 app = HBlogFlask(__name__)
 app.register_all_blueprint()
 
+from sql.cache import restart_clear_cache
+from sql.cache_refresh import refresh
+restart_clear_cache()  # 清理缓存
+
+
+class FirstRefresh(threading.Thread):
+    def run(self):
+        refresh()
+
+first_refresh_th = FirstRefresh()
+first_refresh_th.start()
+refresh_th = threading.Timer(conf["CACHE_REFRESH_INTERVAL"], refresh)
+refresh_th.start()
+
 if __name__ == '__main__':
     logging.info("Server start on 127.0.0.1:8080")
     serve(app, host='0.0.0.0', port="8080")

+ 52 - 38
sql/archive.py

@@ -1,29 +1,30 @@
-from sql import db
+from sql import db, DB
 from sql.cache import (get_archive_from_cache, write_archive_to_cache, delete_archive_from_cache,
                        get_blog_archive_from_cache, write_blog_archive_to_cache, delete_blog_archive_from_cache,
                        delete_all_blog_archive_from_cache)
 from typing import Optional
 
 
-def create_archive(name: str, describe: str):
+def create_archive(name: str, describe: str, mysql: DB = db):
     """ 创建新归档 """
-    cur = db.insert("INSERT INTO archive(Name, DescribeText) "
-                    "VALUES (%s, %s)", name, describe)
+    cur = mysql.insert("INSERT INTO archive(Name, DescribeText) "
+                       "VALUES (%s, %s)", name, describe)
     if cur is None or cur.rowcount == 0:
         return None
-    read_archive(cur.lastrowid)
+    read_archive(cur.lastrowid, mysql)
     return cur.lastrowid
 
 
-def read_archive(archive_id: int):
+def read_archive(archive_id: int, mysql: DB = db, not_cache=False):
     """ 获取归档 ID """
-    res = get_archive_from_cache(archive_id)
-    if res is not None:
-        return res
-
-    cur = db.search("SELECT Name, DescribeText "
-                    "FROM archive "
-                    "WHERE ID=%s", archive_id)
+    if not not_cache:
+        res = get_archive_from_cache(archive_id)
+        if res is not None:
+            return res
+
+    cur = mysql.search("SELECT Name, DescribeText "
+                       "FROM archive "
+                       "WHERE ID=%s", archive_id)
     if cur is None or cur.rowcount == 0:
         return ["", ""]
 
@@ -32,15 +33,16 @@ def read_archive(archive_id: int):
     return res
 
 
-def get_blog_archive(blog_id: int):
+def get_blog_archive(blog_id: int, mysql: DB = db, not_cache=False):
     """ 获取文章的归档 """
-    res = get_blog_archive_from_cache(blog_id)
-    if res is not None:
-        return res
-
-    cur = db.search("SELECT ArchiveID FROM blog_archive_with_name "
-                    "WHERE BlogID=%s "
-                    "ORDER BY ArchiveName", blog_id)
+    if not not_cache:
+        res = get_blog_archive_from_cache(blog_id)
+        if res is not None:
+            return res
+        
+    cur = mysql.search("SELECT ArchiveID FROM blog_archive_with_name "
+                       "WHERE BlogID=%s "
+                       "ORDER BY ArchiveName", blog_id)
     if cur is None:
         return []
 
@@ -49,52 +51,64 @@ def get_blog_archive(blog_id: int):
     return res
 
 
-def delete_archive(archive_id: int):
+def delete_archive(archive_id: int, mysql: DB = db):
     delete_archive_from_cache(archive_id)
     delete_all_blog_archive_from_cache()
-    cur = db.delete("DELETE FROM blog_archive WHERE ArchiveID=%s", archive_id)
+    cur = mysql.delete("DELETE FROM blog_archive WHERE ArchiveID=%s", archive_id)
     if cur is None:
         return False
-    cur = db.delete("DELETE FROM archive WHERE ID=%s", archive_id)
+    cur = mysql.delete("DELETE FROM archive WHERE ID=%s", archive_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
-def add_blog_to_archive(blog_id: int, archive_id: int):
+def add_blog_to_archive(blog_id: int, archive_id: int, mysql: DB = db):
     delete_blog_archive_from_cache(blog_id)
-    cur = db.search("SELECT BlogID FROM blog_archive WHERE BlogID=%s AND ArchiveID=%s", blog_id, archive_id)
+    cur = mysql.search("SELECT BlogID FROM blog_archive WHERE BlogID=%s AND ArchiveID=%s", blog_id, archive_id)
     if cur is None:
         return False
     if cur.rowcount > 0:
         return True
-    cur = db.insert("INSERT INTO blog_archive(BlogID, ArchiveID) VALUES (%s, %s)", blog_id, archive_id)
+    cur = mysql.insert("INSERT INTO blog_archive(BlogID, ArchiveID) VALUES (%s, %s)", blog_id, archive_id)
     if cur is None or cur.rowcount != 1:
         return False
     return True
 
 
-def sub_blog_from_archive(blog_id: int, archive_id: int):
+def sub_blog_from_archive(blog_id: int, archive_id: int, mysql: DB = db):
     delete_blog_archive_from_cache(blog_id)
-    cur = db.delete("DELETE FROM blog_archive WHERE BlogID=%s AND ArchiveID=%s", blog_id, archive_id)
+    cur = mysql.delete("DELETE FROM blog_archive WHERE BlogID=%s AND ArchiveID=%s", blog_id, archive_id)
     if cur is None:
         return False
     return True
 
 
-def get_archive_list(limit: Optional[int] = None, offset: Optional[int] = None):
+def get_archive_list(limit: Optional[int] = None, offset: Optional[int] = None, mysql: DB = db):
     """ 获取归档列表 """
     if limit is not None and offset is not None:
-        cur = db.search("SELECT ID "
-                        "FROM archive "
-                        "ORDER BY Name "
-                        "LIMIT %s "
-                        "OFFSET %s ", limit, offset)
+        cur = mysql.search("SELECT ID "
+                           "FROM archive "
+                           "ORDER BY Name "
+                           "LIMIT %s "
+                           "OFFSET %s ", limit, offset)
     else:
-        cur = db.search("SELECT ID "
-                        "FROM archive "
-                        "ORDER BY Name")
+        cur = mysql.search("SELECT ID "
+                           "FROM archive "
+                           "ORDER BY Name")
 
     if cur is None or cur.rowcount == 0:
         return []
     return [i[0] for i in cur.fetchall()]
+
+
+
+def get_archive_list_iter(mysql: DB = db):
+    """ 获取归档列表 """
+    cur = mysql.search("SELECT ID "
+                       "FROM archive "
+                       "ORDER BY Name")
+
+    if cur is None or cur.rowcount == 0:
+        return []
+    return cur

+ 84 - 69
sql/blog.py

@@ -1,4 +1,4 @@
-from sql import db
+from sql import db, DB
 from sql.base import DBBit
 from sql.archive import add_blog_to_archive
 from sql.cache import (write_blog_to_cache, get_blog_from_cache, delete_blog_from_cache,
@@ -10,19 +10,18 @@ from sql.cache import (write_blog_to_cache, get_blog_from_cache, delete_blog_fro
                        delete_blog_archive_from_cache)
 import object.archive
 
-
 from typing import Optional, List
 
 
 def create_blog(auth_id: int, title: str, subtitle: str, content: str,
-                archive_list: List[object.archive.Archive]) -> bool:
+                archive_list: List[object.archive.Archive], mysql: DB = db) -> bool:
     """ 写入新的blog """
     delete_blog_count_from_cache()
     delete_user_blog_count_from_cache(auth_id)
     # archive cache 在下面循环删除
 
-    cur = db.insert("INSERT INTO blog(Auth, Title, SubTitle, Content) "
-                    "VALUES (%s, %s, %s, %s)", auth_id, title, subtitle, content)
+    cur = mysql.insert("INSERT INTO blog(Auth, Title, SubTitle, Content) "
+                       "VALUES (%s, %s, %s, %s)", auth_id, title, subtitle, content)
     if cur is None or cur.rowcount == 0:
         return False
 
@@ -31,32 +30,33 @@ def create_blog(auth_id: int, title: str, subtitle: str, content: str,
         if not add_blog_to_archive(blog_id, archive.id):
             return False
         delete_archive_blog_count_from_cache(archive.id)
-    read_blog(blog_id)  # 刷新缓存
+    read_blog(blog_id, mysql)  # 刷新缓存
     return True
 
 
-def update_blog(blog_id: int, content: str) -> bool:
+def update_blog(blog_id: int, content: str, mysql: DB = db) -> bool:
     """ 更新博客文章 """
     delete_blog_from_cache(blog_id)
 
-    cur = db.update("Update blog "
-                    "SET UpdateTime=CURRENT_TIMESTAMP(), Content=%s "
-                    "WHERE ID=%s", content, blog_id)
+    cur = mysql.update("Update blog "
+                       "SET UpdateTime=CURRENT_TIMESTAMP(), Content=%s "
+                       "WHERE ID=%s", content, blog_id)
     if cur is None or cur.rowcount != 1:
         return False
-    read_blog(blog_id)  # 刷新缓存
+    read_blog(blog_id, mysql)  # 刷新缓存
     return True
 
 
-def read_blog(blog_id: int) -> list:
+def read_blog(blog_id: int, mysql: DB = db, not_cache=False) -> list:
     """ 读取blog内容 """
-    res = get_blog_from_cache(blog_id)
-    if res is not None:
-        return res
-
-    cur = db.search("SELECT Auth, Title, SubTitle, Content, UpdateTime, CreateTime, Top "
-                    "FROM blog "
-                    "WHERE ID=%s", blog_id)
+    if not not_cache:
+        res = get_blog_from_cache(blog_id)
+        if res is not None:
+            return res
+
+    cur = mysql.search("SELECT Auth, Title, SubTitle, Content, UpdateTime, CreateTime, Top "
+                       "FROM blog "
+                       "WHERE ID=%s", blog_id)
     if cur is None or cur.rowcount == 0:
         return [-1, "", "", "", 0, -1, False]
     res = cur.fetchone()
@@ -64,93 +64,106 @@ def read_blog(blog_id: int) -> list:
     return [*res[:6], res[-1] == DBBit.BIT_1]
 
 
-def delete_blog(blog_id: int):
+def delete_blog(blog_id: int, mysql: DB = db):
     delete_blog_count_from_cache()
     delete_all_archive_blog_count_from_cache()
     delete_all_user_blog_count_from_cache()
     delete_blog_from_cache(blog_id)
     delete_blog_archive_from_cache(blog_id)
 
-    cur = db.delete("DELETE FROM blog_archive WHERE BlogID=%s", blog_id)
+    cur = mysql.delete("DELETE FROM blog_archive WHERE BlogID=%s", blog_id)
     if cur is None:
         return False
-    cur = db.delete("DELETE FROM comment WHERE BlogID=%s", blog_id)
+    cur = mysql.delete("DELETE FROM comment WHERE BlogID=%s", blog_id)
     if cur is None:
         return False
-    cur = db.delete("DELETE FROM blog WHERE ID=%s", blog_id)
+    cur = mysql.delete("DELETE FROM blog WHERE ID=%s", blog_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
-def set_blog_top(blog_id: int, top: bool = True):
+def set_blog_top(blog_id: int, top: bool = True, mysql: DB = db):
     delete_blog_from_cache(blog_id)
-    cur = db.update("UPDATE blog "
-                    "SET Top=%s "
-                    "WHERE ID=%s", 1 if top else 0, blog_id)
+    cur = mysql.update("UPDATE blog "
+                       "SET Top=%s "
+                       "WHERE ID=%s", 1 if top else 0, blog_id)
     if cur is None or cur.rowcount != 1:
         return False
-    read_blog(blog_id)  # 刷新缓存
+    read_blog(blog_id, mysql)  # 刷新缓存
     return True
 
 
-def get_blog_list(limit: Optional[int] = None, offset: Optional[int] = None) -> list:
+def get_blog_list(limit: Optional[int] = None, offset: Optional[int] = None, mysql: DB = db) -> list:
     """ 获得 blog 列表 """
     if limit is not None and offset is not None:
-        cur = db.search("SELECT ID "
-                        "FROM blog "
-                        "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle "
-                        "LIMIT %s OFFSET %s", limit, offset)
+        cur = mysql.search("SELECT ID "
+                           "FROM blog "
+                           "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle "
+                           "LIMIT %s OFFSET %s", limit, offset)
     else:
-        cur = db.search("SELECT ID "
-                        "FROM blog "
-                        "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle")
+        cur = mysql.search("SELECT ID "
+                           "FROM blog "
+                           "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle")
     if cur is None or cur.rowcount == 0:
         return []
     return [i[0] for i in cur.fetchall()]
 
 
-def get_blog_list_not_top(limit: Optional[int] = None, offset: Optional[int] = None) -> list:
+def get_blog_list_iter(mysql: DB = db):
+    """ 获得 blog 列表 """
+    cur = mysql.search("SELECT ID "
+                       "FROM blog "
+                       "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle")
+    if cur is None or cur.rowcount == 0:
+        return []
+    return cur
+
+
+def get_blog_list_not_top(limit: Optional[int] = None, offset: Optional[int] = None, mysql: DB = db) -> list:
     """ 获得blog列表 忽略置顶 """
     if limit is not None and offset is not None:
-        cur = db.search("SELECT ID "
-                        "FROM blog "
-                        "ORDER BY CreateTime DESC, Title, SubTitle "
-                        "LIMIT %s OFFSET %s", limit, offset)
+        cur = mysql.search("SELECT ID "
+                           "FROM blog "
+                           "ORDER BY CreateTime DESC, Title, SubTitle "
+                           "LIMIT %s OFFSET %s", limit, offset)
     else:
-        cur = db.search("SELECT ID "
-                        "FROM blog "
-                        "ORDER BY CreateTime DESC, Title, SubTitle")
+        cur = mysql.search("SELECT ID "
+                           "FROM blog "
+                           "ORDER BY CreateTime DESC, Title, SubTitle")
     if cur is None or cur.rowcount == 0:
         return []
     return [i[0] for i in cur.fetchall()]
 
 
-def get_archive_blog_list(archive_id, limit: Optional[int] = None, offset: Optional[int] = None) -> list:
+def get_archive_blog_list(archive_id, limit: Optional[int] = None,
+                          offset: Optional[int] = None,
+                          mysql: DB = db) -> list:
     """ 获得指定归档的 blog 列表 """
     if limit is not None and offset is not None:
-        cur = db.search("SELECT BlogID "
-                        "FROM blog_with_archive "
-                        "WHERE ArchiveID=%s "
-                        "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle "
-                        "LIMIT %s OFFSET %s", archive_id, limit, offset)
+        cur = mysql.search("SELECT BlogID "
+                           "FROM blog_with_archive "
+                           "WHERE ArchiveID=%s "
+                           "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle "
+                           "LIMIT %s OFFSET %s", archive_id, limit, offset)
     else:
-        cur = db.search("SELECT BlogID "
-                        "FROM blog_with_archive "
-                        "WHERE ArchiveID=%s "
-                        "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle")
+        cur = mysql.search("SELECT BlogID "
+                           "FROM blog_with_archive "
+                           "WHERE ArchiveID=%s "
+                           "ORDER BY Top DESC, CreateTime DESC, Title, SubTitle")
     if cur is None or cur.rowcount == 0:
         return []
     return [i[0] for i in cur.fetchall()]
 
 
-def get_blog_count() -> int:
+def get_blog_count(mysql: DB = db, not_cache=False) -> int:
     """ 统计 blog 个数 """
-    res = get_blog_count_from_cache()
-    if res is not None:
-        return res
+    if not not_cache:
+        res = get_blog_count_from_cache()
+        if res is not None:
+            return res
 
-    cur = db.search("SELECT COUNT(*) FROM blog")
+    cur = mysql.search("SELECT COUNT(*) FROM blog")
     if cur is None or cur.rowcount == 0:
         return 0
 
@@ -159,13 +172,14 @@ def get_blog_count() -> int:
     return res
 
 
-def get_archive_blog_count(archive_id) -> int:
+def get_archive_blog_count(archive_id, mysql: DB = db, not_cache=False) -> int:
     """ 统计指定归档的 blog 个数 """
-    res = get_archive_blog_count_from_cache(archive_id)
-    if res is not None:
-        return res
+    if not not_cache:
+        res = get_archive_blog_count_from_cache(archive_id)
+        if res is not None:
+            return res
 
-    cur = db.search("SELECT COUNT(*) FROM blog_with_archive WHERE ArchiveID=%s", archive_id)
+    cur = mysql.search("SELECT COUNT(*) FROM blog_with_archive WHERE ArchiveID=%s", archive_id)
     if cur is None or cur.rowcount == 0:
         return 0
 
@@ -174,13 +188,14 @@ def get_archive_blog_count(archive_id) -> int:
     return res
 
 
-def get_user_blog_count(user_id: int) -> int:
+def get_user_blog_count(user_id: int, mysql: DB = db, not_cache=False) -> int:
     """ 获得指定用户的 blog 个数 """
-    res = get_user_blog_count_from_cache(user_id)
-    if res is not None:
-        return res
+    if not not_cache:
+        res = get_user_blog_count_from_cache(user_id)
+        if res is not None:
+            return res
 
-    cur = db.search("SELECT COUNT(*) FROM blog WHERE Auth=%s", user_id)
+    cur = mysql.search("SELECT COUNT(*) FROM blog WHERE Auth=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return 0
 

+ 5 - 3
sql/cache.py

@@ -1,4 +1,4 @@
-from sql import cache
+from sql import cache, DB
 from sql.base import DBBit
 from configure import conf
 
@@ -6,7 +6,6 @@ from redis import RedisError
 from functools import wraps
 from datetime import datetime
 
-
 CACHE_TIME = int(conf["CACHE_EXPIRE"])
 CACHE_PREFIX = conf["CACHE_PREFIX"]
 
@@ -21,7 +20,9 @@ def __try_redis(ret=None):
                 cache.logger.error(f"Redis error with {args} {kwargs}", exc_info=True, stack_info=True)
                 return ret
             return res
+
         return try_func
+
     return try_redis
 
 
@@ -247,7 +248,8 @@ def write_blog_archive_to_cache(blog_id: int, archive):
     cache.delete(cache_name)
     if len(archive) == 0:
         cache.rpush(cache_name, -1)
-    cache.rpush(cache_name, *archive)
+    else:
+        cache.rpush(cache_name, *archive)
     cache.expire(cache_name, CACHE_TIME)
 
 

+ 62 - 0
sql/cache_refresh.py

@@ -0,0 +1,62 @@
+from sql import DB
+from configure import conf
+from sql.archive import read_archive, get_archive_list_iter, get_blog_archive
+from sql.blog import read_blog, get_blog_count, get_archive_blog_count, get_user_blog_count, get_blog_list_iter
+from sql.comment import read_comment, read_comment_list_iter, get_user_comment_count
+from sql.msg import read_msg, read_msg_list_iter, get_msg_count, get_user_msg_count
+from sql.user import (read_user, get_user_list_iter, get_role_list_iter,
+                      get_user_email, get_role_name, check_role, role_authority)
+import logging.handlers
+import os
+
+refresh_logger = logging.getLogger("main.refresh")
+refresh_logger.setLevel(conf["LOG_LEVEL"])
+if len(conf["LOG_HOME"]) > 0:
+    handle = logging.handlers.TimedRotatingFileHandler(
+        os.path.join(conf["LOG_HOME"], f"redis-refresh.log"), backupCount=10)
+    handle.setFormatter(logging.Formatter(conf["LOG_FORMAT"]))
+    refresh_logger.addHandler(handle)
+
+
+def refresh():
+    mysql = DB(host=conf["MYSQL_URL"],
+               name=conf["MYSQL_NAME"],
+               passwd=conf["MYSQL_PASSWD"],
+               port=conf["MYSQL_PORT"],
+               database=conf["MYSQL_DATABASE"])
+
+    refresh_logger.info("refresh redis cache started.")
+
+    for i in get_archive_list_iter():
+        read_archive(i[0], mysql, not_cache=True)
+        get_archive_blog_count(i[0], mysql, not_cache=True)
+
+    for i in get_blog_list_iter():
+        read_blog(i[0], mysql, not_cache=True)
+        get_blog_archive(i[0], mysql, not_cache=True)
+    get_blog_count(mysql, not_cache=True)
+
+    for i in read_comment_list_iter():
+        read_comment(i[0], mysql, not_cache=True)
+        print(f"comment {i}")
+
+    for i in read_msg_list_iter():
+        read_msg(i[0], mysql, not_cache=True)
+        print(f"msg {i}")
+    get_msg_count(mysql, not_cache=True)
+
+    for i in get_user_list_iter():
+        email = get_user_email(i[0], mysql, not_cache=True)
+        get_user_blog_count(i[0], mysql, not_cache=True)
+        get_user_comment_count(i[0], mysql, not_cache=True)
+        get_user_msg_count(i[0], mysql, not_cache=True)
+        read_user(email, mysql, not_cache=True)
+        print(f"user: {i}")
+
+    for i in get_role_list_iter():
+        get_role_name(i[0], mysql, not_cache=True)
+        for a in role_authority:
+            check_role(i[0], a, mysql, not_cache=True)
+        print(f"role {i}")
+
+    refresh_logger.info("refresh redis cache finished.")

+ 39 - 20
sql/comment.py

@@ -1,39 +1,50 @@
-from sql import db
+from sql import db, DB
 from sql.cache import (get_comment_from_cache, write_comment_to_cache, delete_comment_from_cache,
                        get_user_comment_count_from_cache, write_user_comment_count_to_cache,
                        delete_all_user_comment_count_from_cache, delete_user_comment_count_from_cache)
 
 
-def read_comment_list(blog_id: int):
+def read_comment_list(blog_id: int, mysql: DB = db):
     """ 读取文章的 comment """
-    cur = db.search("SELECT CommentID "
-                    "FROM comment_user "
-                    "WHERE BlogID=%s "
-                    "ORDER BY UpdateTime DESC", blog_id)
+    cur = mysql.search("SELECT CommentID "
+                       "FROM comment_user "
+                       "WHERE BlogID=%s "
+                       "ORDER BY UpdateTime DESC", blog_id)
     if cur is None or cur.rowcount == 0:
         return []
     return [i[0] for i in cur.fetchall()]
 
 
-def create_comment(blog_id: int, user_id: int, content: str):
+def read_comment_list_iter(mysql: DB = db):
+    """ 读取文章的 comment """
+    cur = mysql.search("SELECT CommentID "
+                       "FROM comment_user "
+                       "ORDER BY UpdateTime DESC")
+    if cur is None or cur.rowcount == 0:
+        return []
+    return cur
+
+
+def create_comment(blog_id: int, user_id: int, content: str, mysql: DB = db):
     """ 新建 comment """
     delete_user_comment_count_from_cache(user_id)
 
-    cur = db.insert("INSERT INTO comment(BlogID, Auth, Content) "
-                    "VALUES (%s, %s, %s)", blog_id, user_id, content)
+    cur = mysql.insert("INSERT INTO comment(BlogID, Auth, Content) "
+                       "VALUES (%s, %s, %s)", blog_id, user_id, content)
     if cur is None or cur.rowcount == 0:
         return False
-    read_comment(cur.lastrowid)  # 刷新缓存
+    read_comment(cur.lastrowid, mysql)  # 刷新缓存
     return True
 
 
-def read_comment(comment_id: int):
+def read_comment(comment_id: int, mysql: DB = db, not_cache=False):
     """ 读取 comment """
-    res = get_comment_from_cache(comment_id)
-    if res is not None:
-        return res
+    if not not_cache:
+        res = get_comment_from_cache(comment_id)
+        if res is not None:
+            return res
 
-    cur = db.search("SELECT BlogID, Email, Content, UpdateTime FROM comment_user WHERE CommentID=%s", comment_id)
+    cur = mysql.search("SELECT BlogID, Email, Content, UpdateTime FROM comment_user WHERE CommentID=%s", comment_id)
     if cur is None or cur.rowcount == 0:
         return [-1, "", "", 0]
 
@@ -42,19 +53,27 @@ def read_comment(comment_id: int):
     return res
 
 
-def delete_comment(comment_id):
+def delete_comment(comment_id: int, mysql: DB = db):
     """ 删除评论 """
     delete_comment_from_cache(comment_id)
     delete_all_user_comment_count_from_cache()
-    cur = db.delete("DELETE FROM comment WHERE ID=%s", comment_id)
+    cur = mysql.delete("DELETE FROM comment WHERE ID=%s", comment_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
-def get_user_comment_count(user_id: int):
+def get_user_comment_count(user_id: int, mysql: DB = db, not_cache=False):
     """ 读取指定用户的 comment 个数 """
-    cur = db.search("SELECT COUNT(*) FROM comment WHERE Auth=%s", user_id)
+    if not not_cache:
+        res = get_user_comment_count_from_cache(user_id)
+        if res is not None:
+            return res
+
+    cur = mysql.search("SELECT COUNT(*) FROM comment WHERE Auth=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return 0
-    return cur.fetchone()[0]
+
+    res = cur.fetchone()[0]
+    write_user_comment_count_to_cache(user_id, res)
+    return res

+ 58 - 44
sql/msg.py

@@ -1,4 +1,4 @@
-from sql import db
+from sql import db, DB
 from sql.base import DBBit
 from sql.cache import (get_msg_from_cache, write_msg_to_cache, delete_msg_from_cache,
                        get_msg_cout_from_cache, write_msg_count_to_cache, delete_msg_count_from_cache,
@@ -8,57 +8,69 @@ from sql.cache import (get_msg_from_cache, write_msg_to_cache, delete_msg_from_c
 from typing import Optional
 
 
-
-def read_msg_list(limit: Optional[int] = None, offset: Optional[int] = None, show_secret: bool = False):
+def read_msg_list(limit: Optional[int] = None,
+                  offset: Optional[int] = None,
+                  show_secret: bool = False,
+                  mysql: DB = db):
     if show_secret:
         if limit is not None and offset is not None:
-            cur = db.search("SELECT MsgID "
-                            "FROM message_user "
-                            "ORDER BY UpdateTime DESC "
-                            "LIMIT %s "
-                            "OFFSET %s", limit, offset)
+            cur = mysql.search("SELECT MsgID "
+                               "FROM message_user "
+                               "ORDER BY UpdateTime DESC "
+                               "LIMIT %s "
+                               "OFFSET %s", limit, offset)
         else:
-            cur = db.search("SELECT MsgID "
-                            "FROM message_user "
-                            "ORDER BY UpdateTime DESC")
+            cur = mysql.search("SELECT MsgID "
+                               "FROM message_user "
+                               "ORDER BY UpdateTime DESC")
     else:
         if limit is not None and offset is not None:
-            cur = db.search("SELECT MsgID "
-                            "FROM message_user "
-                            "WHERE Secret=0 "
-                            "ORDER BY UpdateTime DESC "
-                            "LIMIT %s "
-                            "OFFSET %s", limit, offset)
+            cur = mysql.search("SELECT MsgID "
+                               "FROM message_user "
+                               "WHERE Secret=0 "
+                               "ORDER BY UpdateTime DESC "
+                               "LIMIT %s "
+                               "OFFSET %s", limit, offset)
         else:
-            cur = db.search("SELECT MsgID "
-                            "FROM message_user "
-                            "WHERE Secret=0 "
-                            "ORDER BY UpdateTime DESC")
+            cur = mysql.search("SELECT MsgID "
+                               "FROM message_user "
+                               "WHERE Secret=0 "
+                               "ORDER BY UpdateTime DESC")
     if cur is None or cur.rowcount == 0:
         return []
     return [i[0] for i in cur.fetchall()]
 
 
-def create_msg(auth: int, content: str, secret: bool = False):
+def read_msg_list_iter(mysql: DB = db):
+    cur = mysql.search("SELECT MsgID "
+                       "FROM message_user "
+                       "ORDER BY UpdateTime DESC")
+    if cur is None or cur.rowcount == 0:
+        return []
+    return cur
+
+
+def create_msg(auth: int, content: str, secret: bool = False, mysql: DB = db):
     delete_msg_count_from_cache()
     delete_user_msg_count_from_cache(auth)
 
-    cur = db.insert("INSERT INTO message(Auth, Content, Secret) "
-                    "VALUES (%s, %s, %s)", auth, content, 1 if secret else 0)
+    cur = mysql.insert("INSERT INTO message(Auth, Content, Secret) "
+                       "VALUES (%s, %s, %s)", auth, content, 1 if secret else 0)
     if cur is None or cur.rowcount != 1:
         return None
-    read_msg(cur.lastrowid)  # 刷新缓存
+    read_msg(cur.lastrowid, mysql)  # 刷新缓存
     return cur.lastrowid
 
 
-def read_msg(msg_id: int):
-    res = get_msg_from_cache(msg_id)
-    if res is not None:
-        return res
+def read_msg(msg_id: int, mysql: DB = db, not_cache=False):
+    if not not_cache:
+        res = get_msg_from_cache(msg_id)
+        if res is not None:
+            return res
 
-    cur = db.search("SELECT Email, Content, UpdateTime, Secret "
-                    "FROM message_user "
-                    "WHERE MsgID=%s", msg_id)
+    cur = mysql.search("SELECT Email, Content, UpdateTime, Secret "
+                       "FROM message_user "
+                       "WHERE MsgID=%s", msg_id)
     if cur is None or cur.rowcount == 0:
         return ["", "", "0", False]
 
@@ -67,22 +79,23 @@ def read_msg(msg_id: int):
     return [*res[:3], res[-1] == DBBit.BIT_1]
 
 
-def delete_msg(msg_id: int):
+def delete_msg(msg_id: int, mysql: DB = db):
     delete_msg_from_cache(msg_id)
     delete_msg_count_from_cache()
     delete_all_user_msg_count_from_cache()
-    cur = db.delete("DELETE FROM message WHERE ID=%s", msg_id)
+    cur = mysql.delete("DELETE FROM message WHERE ID=%s", msg_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
-def get_msg_count():
-    res = get_msg_cout_from_cache()
-    if res is not None:
-        return res
+def get_msg_count(mysql: DB = db, not_cache=False):
+    if not not_cache:
+        res = get_msg_cout_from_cache()
+        if res is not None:
+            return res
 
-    cur = db.search("SELECT COUNT(*) FROM message")
+    cur = mysql.search("SELECT COUNT(*) FROM message")
     if cur is None or cur.rowcount == 0:
         return 0
     res = cur.fetchone()[0]
@@ -90,12 +103,13 @@ def get_msg_count():
     return res
 
 
-def get_user_msg_count(user_id: int):
-    res = get_user_msg_count_from_cache(user_id)
-    if res is not None:
-        return res
+def get_user_msg_count(user_id: int, mysql: DB = db, not_cache=False):
+    if not not_cache:
+        res = get_user_msg_count_from_cache(user_id)
+        if res is not None:
+            return res
 
-    cur = db.search("SELECT COUNT(*) FROM message WHERE Auth=%s", user_id)
+    cur = mysql.search("SELECT COUNT(*) FROM message WHERE Auth=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return 0
     res = cur.fetchone()[0]

+ 70 - 49
sql/user.py

@@ -1,4 +1,4 @@
-from sql import db
+from sql import db, DB
 from sql.base import DBBit
 from sql.cache import (get_user_from_cache, write_user_to_cache, delete_user_from_cache,
                        get_user_email_from_cache, write_user_email_to_cache, delete_user_email_from_cache,
@@ -14,13 +14,14 @@ role_authority = ["WriteBlog", "WriteComment", "WriteMsg", "CreateUser",
                   "ConfigureSystem", "ReadSystem"]
 
 
-def read_user(email: str):
+def read_user(email: str, mysql: DB = db, not_cache=False):
     """ 读取用户 """
-    res = get_user_from_cache(email)
-    if res is not None:
-        return res
+    if not not_cache:
+        res = get_user_from_cache(email)
+        if res is not None:
+            return res
 
-    cur = db.search("SELECT PasswdHash, Role, ID FROM user WHERE Email=%s", email)
+    cur = mysql.search("SELECT PasswdHash, Role, ID FROM user WHERE Email=%s", email)
     if cur is None or cur.rowcount != 1:
         return ["", -1, -1]
 
@@ -29,63 +30,65 @@ def read_user(email: str):
     return res
 
 
-def create_user(email: str, passwd: str):
+def create_user(email: str, passwd: str, mysql: DB = db):
     """ 创建用户 """
     if len(email) == 0:
         return None
 
-    cur = db.search("SELECT COUNT(*) FROM user")
+    cur = mysql.search("SELECT COUNT(*) FROM user")
     passwd = object.user.User.get_passwd_hash(passwd)
     if cur is None or cur.rowcount == 0 or cur.fetchone()[0] == 0:
         # 创建为管理员用户
-        cur = db.insert("INSERT INTO user(Email, PasswdHash, Role) "
-                        "VALUES (%s, %s, %s)", email, passwd, 1)
+        cur = mysql.insert("INSERT INTO user(Email, PasswdHash, Role) "
+                           "VALUES (%s, %s, %s)", email, passwd, 1)
     else:
-        cur = db.insert("INSERT INTO user(Email, PasswdHash) "
-                        "VALUES (%s, %s)", email, passwd)
+        cur = mysql.insert("INSERT INTO user(Email, PasswdHash) "
+                           "VALUES (%s, %s)", email, passwd)
     if cur is None or cur.rowcount != 1:
         return None
-    read_user(cur.lastrowid)  # 刷新缓存
+    read_user(email, mysql)  # 刷新缓存
     return cur.lastrowid
 
 
-def delete_user(user_id: int):
+def delete_user(user_id: int, mysql: DB = db):
     """ 删除用户 """
     delete_user_from_cache(get_user_email(user_id))
     delete_user_email_from_cache(user_id)
 
-    cur = db.delete("DELETE FROM message WHERE Auth=%s", user_id)
+    cur = mysql.delete("DELETE FROM message WHERE Auth=%s", user_id)
     if cur is None:
         return False
-    cur = db.delete("DELETE FROM comment WHERE Auth=%s", user_id)
+    cur = mysql.delete("DELETE FROM comment WHERE Auth=%s", user_id)
     if cur is None:
         return False
-    cur = db.delete("DELETE FROM blog WHERE Auth=%s", user_id)
+    cur = mysql.delete("DELETE FROM blog WHERE Auth=%s", user_id)
     if cur is None:
         return False
-    cur = db.delete("DELETE FROM user WHERE ID=%s", user_id)
+    cur = mysql.delete("DELETE FROM user WHERE ID=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
-def change_passwd_hash(user_email: str, passwd_hash: str):
+def change_passwd_hash(user_email: str, passwd_hash: str, mysql: DB = db):
     delete_user_from_cache(user_email)
-    cur = db.update("UPDATE user "
-                    "SET PasswdHash=%s "
-                    "WHERE Email=%s", passwd_hash, user_email)
+    cur = mysql.update("UPDATE user "
+                       "SET PasswdHash=%s "
+                       "WHERE Email=%s", passwd_hash, user_email)
+    read_user(user_email, mysql)  # 刷新缓存
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
-def get_user_email(user_id):
+def get_user_email(user_id, mysql: DB = db, not_cache=False):
     """ 获取用户邮箱 """
-    res = get_user_email_from_cache(user_id)
-    if res is not None:
-        return res
+    if not not_cache:
+        res = get_user_email_from_cache(user_id)
+        if res is not None:
+            return res
 
-    cur = db.search("SELECT Email FROM user WHERE ID=%s", user_id)
+    cur = mysql.search("SELECT Email FROM user WHERE ID=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return None
 
@@ -104,46 +107,47 @@ def __authority_to_sql(authority):
     return ",".join(sql), args
 
 
-def create_role(name: str, authority: List[str]):
+def create_role(name: str, authority: List[str], mysql: DB = db):
     cur = db.insert("INSERT INTO role(RoleName) VALUES (%s)", name)
     if cur is None or cur.rowcount == 0:
         return False
 
     sql, args = __authority_to_sql({i: (1 if i in authority else 0) for i in role_authority})
-    cur = db.update(f"UPDATE role "
-                    f"SET {sql} "
-                    f"WHERE RoleName=%s", *args, name)
+    cur = mysql.update(f"UPDATE role "
+                       f"SET {sql} "
+                       f"WHERE RoleName=%s", *args, name)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
-def delete_role(role_id: int):
+def delete_role(role_id: int, mysql: DB = db):
     delete_role_name_from_cache(role_id)
     delete_role_operate_from_cache(role_id)
 
-    cur = db.delete("DELETE FROM role WHERE RoleID=%s", role_id)
+    cur = mysql.delete("DELETE FROM role WHERE RoleID=%s", role_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
-def set_user_role(role_id: int, user_id: str):
-    cur = db.update("UPDATE user "
-                    "SET Role=%s "
-                    "WHERE ID=%s", role_id, user_id)
+def set_user_role(role_id: int, user_id: str, mysql: DB = db):
+    cur = mysql.update("UPDATE user "
+                       "SET Role=%s "
+                       "WHERE ID=%s", role_id, user_id)
     if cur is None or cur.rowcount == 0:
         return False
     return True
 
 
-def get_role_name(role: int):
+def get_role_name(role: int, mysql: DB = db, not_cache=False):
     """ 获取用户角色名称 """
-    res = get_role_name_from_cache(role)
-    if res is not None:
-        return res
+    if not not_cache:
+        res = get_role_name_from_cache(role)
+        if res is not None:
+            return res
 
-    cur = db.search("SELECT RoleName FROM role WHERE RoleID=%s", role)
+    cur = mysql.search("SELECT RoleName FROM role WHERE RoleID=%s", role)
     if cur is None or cur.rowcount == 0:
         return None
 
@@ -156,16 +160,17 @@ def __check_operate(operate):
     return operate in role_authority
 
 
-def check_role(role: int, operate: str):
+def check_role(role: int, operate: str, mysql: DB = db, not_cache=False):
     """ 检查角色权限(通过角色ID) """
     if not __check_operate(operate):  # 检查, 防止SQL注入
         return False
 
-    res = get_role_operate_from_cache(role, operate)
-    if res is not None:
-        return res
+    if not not_cache:
+        res = get_role_operate_from_cache(role, operate)
+        if res is not None:
+            return res
 
-    cur = db.search(f"SELECT {operate} FROM role WHERE RoleID=%s", role)
+    cur = mysql.search(f"SELECT {operate} FROM role WHERE RoleID=%s", role)
     if cur is None or cur.rowcount == 0:
         return False
 
@@ -174,9 +179,25 @@ def check_role(role: int, operate: str):
     return res
 
 
-def get_role_list():
+def get_role_list(mysql: DB = db):
     """ 获取归档列表 """
-    cur = db.search("SELECT RoleID, RoleName FROM role")
+    cur = mysql.search("SELECT RoleID, RoleName FROM role")
     if cur is None or cur.rowcount == 0:
         return []
     return cur.fetchall()
+
+
+def get_role_list_iter(mysql: DB = db):
+    """ 获取归档列表 """
+    cur = mysql.search("SELECT RoleID, RoleName FROM role")
+    if cur is None or cur.rowcount == 0:
+        return []
+    return cur
+
+
+def get_user_list_iter(mysql: DB = db):
+    """ 获取归档列表 """
+    cur = mysql.search("SELECT ID FROM user")
+    if cur is None or cur.rowcount == 0:
+        return []
+    return cur