Răsfoiți Sursa

feat: 留言启用redis缓存机制

SongZihuan 2 ani în urmă
părinte
comite
adf95a96c3
7 a modificat fișierele cu 153 adăugiri și 5 ștergeri
  1. 5 0
      configure/__init__.py
  2. BIN
      requirements.txt
  3. 7 1
      sql/__init__.py
  4. 87 0
      sql/cache.py
  5. 32 4
      sql/msg.py
  6. 2 0
      sql/mysql.py
  7. 20 0
      sql/redis.py

+ 5 - 0
configure/__init__.py

@@ -15,6 +15,11 @@ conf = {
     "MYSQL_PASSWD": "123456",
     "MYSQL_PORT": 3306,
     "MYSQL_DATABASE": "HBlog",
+    "REDIS_HOST": "localhost",
+    "REDIS_PORT": 6379,
+    "REDIS_NAME": "localhost",
+    "REDIS_PASSWD": "123456",
+    "REDIS_DATABASE": 0,
     "MAIL_SERVER": "",
     "MAIL_PORT": "",
     "MAIL_USE_TLS": False,

BIN
requirements.txt


+ 7 - 1
sql/__init__.py

@@ -1,10 +1,16 @@
 from sql.mysql import MysqlDB
+from sql.redis import RedisDB
 from configure import conf
 
-
 DB = MysqlDB
 db = DB(host=conf["MYSQL_URL"],
         name=conf["MYSQL_NAME"],
         passwd=conf["MYSQL_PASSWD"],
         port=conf["MYSQL_PORT"],
         database=conf["MYSQL_DATABASE"])
+
+cache = redis.RedisDB(host=conf["REDIS_HOST"],
+                      port=conf["REDIS_PORT"],
+                      username=conf["REDIS_NAME"],
+                      passwd=conf["REDIS_PASSWD"],
+                      db=conf["REDIS_DATABASE"])

+ 87 - 0
sql/cache.py

@@ -0,0 +1,87 @@
+from sql import cache
+
+from redis import RedisError
+from functools import wraps
+from datetime import datetime
+
+
+def __try_redis(ret=None):
+    def try_redis(func):
+        @wraps(func)
+        def try_func(*args, **kwargs):
+            try:
+                res = func(*args, **kwargs)
+            except RedisError:
+                cache.logger.error(f"Redis error with {args} {kwargs}", exc_info=True, stack_info=True)
+                return ret
+            return res
+        return try_func
+    return try_redis
+
+
+@__try_redis(None)
+def read_msg_from_cache(msg_id: int):
+    msg = cache.hgetall(f"cache:msg:{msg_id}")
+    if len(msg) != 4:
+        return None
+    return [msg.get("Email", ""), msg.get("Content"), msg.get("UpdateTime", "0"), bool(msg.get("Secret", False))]
+
+
+@__try_redis(None)
+def write_msg_to_cache(msg_id: int, email: str, content: str, update_time: str | datetime, secret: bool):
+    cache_name = f"cache:msg:{msg_id}"
+    cache.delete(cache_name)
+    cache.hset(cache_name, mapping={
+        "Email": email,
+        "Content": content,
+        "UpdateTime": str(update_time),
+        "Secret": str(secret)
+    })
+    cache.expire(cache_name, 3600)
+
+
+@__try_redis(None)
+def delete_msg_from_cache(msg_id: int):
+    cache.delete(f"cache:msg:{msg_id}")
+
+
+@__try_redis(None)
+def get_msg_cout_from_cache():
+    count = cache.get("cache:msg_count")
+    if count is not None:
+        return int(count)
+    return
+
+
+@__try_redis(None)
+def write_msg_count_to_cache(count):
+    count = cache.set("cache:msg_count", str(count))
+    cache.expire("cache:msg_count", 3600)
+    return count
+
+
+@__try_redis(None)
+def delete_msg_count_from_cache():
+    cache.delete("cache:msg_count")
+
+
+@__try_redis(None)
+def get_user_msg_cout_from_cache(user_id: int):
+    count = cache.get(f"cache:msg_count:{user_id}")
+    if count is not None:
+        return int(count)
+    return
+
+
+@__try_redis(None)
+def write_user_msg_count_to_cache(user_id, count):
+    cache_name = f"cache:msg_count:{user_id}"
+    count = cache.set(cache_name, str(count))
+    cache.expire(cache_name, 3600)
+    return count
+
+
+@__try_redis(None)
+def delete_all_user_msg_count_from_cache():
+    for i in cache.keys("cache:msg_count:*"):
+        cache.delete(i)

+ 32 - 4
sql/msg.py

@@ -1,7 +1,13 @@
 from sql import db
+from sql.cache import (read_msg_from_cache, write_msg_to_cache, delete_msg_from_cache,
+                       get_msg_cout_from_cache, write_msg_count_to_cache, delete_msg_count_from_cache,
+                       get_user_msg_cout_from_cache, write_user_msg_count_to_cache,
+                       delete_all_user_msg_count_from_cache)
+
 from typing import Optional
 
 
+
 def read_msg_list(limit: Optional[int] = None, offset: Optional[int] = None, show_secret: bool = False):
     if show_secret:
         if limit is not None and offset is not None:
@@ -42,15 +48,25 @@ def create_msg(auth: int, content: str, secret: bool = False):
 
 
 def read_msg(msg_id: int):
+    res = read_msg_from_cache(msg_id)
+    if res is not None:
+        return res
+
     cur = db.search("SELECT Email, Content, UpdateTime, Secret "
                     "FROM message_user "
                     "WHERE MsgID=%s", msg_id)
     if cur is None or cur.rowcount == 0:
-        return ["", "", 0, False]
-    return cur.fetchone()
+        return ["", "", "0", False]
+
+    res = cur.fetchone()
+    write_msg_to_cache(msg_id, *res)
+    return res
 
 
 def delete_msg(msg_id: int):
+    delete_msg_from_cache(msg_id)
+    delete_msg_count_from_cache()
+    delete_all_user_msg_count_from_cache()
     cur = db.delete("DELETE FROM message WHERE ID=%s", msg_id)
     if cur is None or cur.rowcount == 0:
         return False
@@ -58,14 +74,26 @@ def delete_msg(msg_id: int):
 
 
 def get_msg_count():
+    res = get_msg_cout_from_cache()
+    if res is not None:
+        return res
+
     cur = db.search("SELECT COUNT(*) FROM message")
     if cur is None or cur.rowcount == 0:
         return 0
-    return cur.fetchone()[0]
+    res = cur.fetchone()[0]
+    write_msg_count_to_cache(res)
+    return res
 
 
 def get_user_msg_count(user_id: int):
+    res = get_user_msg_cout_from_cache(user_id)
+    if res is not None:
+        return res
+
     cur = db.search("SELECT COUNT(*) FROM message WHERE Auth=%s", user_id)
     if cur is None or cur.rowcount == 0:
         return 0
-    return cur.fetchone()[0]
+    res = cur.fetchone()[0]
+    write_user_msg_count_to_cache(user_id, res)
+    return res

+ 2 - 0
sql/mysql.py

@@ -31,6 +31,8 @@ class MysqlDB(Database):
                                        database=self.database)
         except pymysql.err.OperationalError:
             raise
+
+        # mysql 不算线程安全的
         self._cursor = self._db.cursor()
         self._lock = threading.RLock()
         self.logger.info(f"MySQL({self._name}@{self._host}) connect")

+ 20 - 0
sql/redis.py

@@ -0,0 +1,20 @@
+import redis
+import logging
+import logging.handlers
+from configure import conf
+import os
+
+
+class RedisDB(redis.StrictRedis):
+    def __init__(self, host, port, username, passwd, db):
+        super().__init__(host=host, port=port, username=username, password=passwd, db=db, decode_responses=True)
+
+        # redis是线程安全的
+
+        self.logger = logging.getLogger("main.database")
+        self.logger.setLevel(conf["LOG_LEVEL"])
+        if len(conf["LOG_HOME"]) > 0:
+            handle = logging.handlers.TimedRotatingFileHandler(
+                os.path.join(conf["LOG_HOME"], f"redis-{username}@{host}.log"), backupCount=10)
+            handle.setFormatter(logging.Formatter(conf["LOG_FORMAT"]))
+            self.logger.addHandler(handle)