Przeglądaj źródła

feat: user表加锁

SongZihuan 3 lat temu
rodzic
commit
6a7d96f78f
14 zmienionych plików z 250 dodań i 139 usunięć
  1. 40 8
      app/auth/views.py
  2. 3 17
      app/news/views.py
  3. 11 32
      app/store/views.py
  4. 5 3
      app/web.py
  5. 2 1
      app/web_goods.py
  6. 58 38
      app/web_user.py
  7. 45 5
      core/user.py
  8. 1 7
      equipment/scan_user.py
  9. 2 1
      init.sql
  10. 9 2
      sql/mysql_db.py
  11. 31 8
      sql/user.py
  12. 2 0
      tk_ui/admin.py
  13. 15 7
      tk_ui/station.py
  14. 26 10
      tk_ui/station_event.py

+ 40 - 8
app/auth/views.py

@@ -1,5 +1,5 @@
 import math
-
+import functools
 from flask import render_template, Blueprint, Flask, request, url_for, redirect, flash, abort
 from wtforms import StringField, PasswordField, SubmitField
 from wtforms.validators import DataRequired
@@ -34,6 +34,40 @@ class LoginForm(FlaskForm):
     submit = SubmitField("登录")
 
 
+def web_user_required(f):
+    @login_required
+    @functools.wraps(f)
+    def func(*args, **kwargs):
+        if not current_user.update_info():
+            logout()
+            flash("用户错误")
+            abort(403)
+        return f(*args, **kwargs)
+
+    return func
+
+
+def manager_required(f):
+    """
+    管理员权限
+    :return:
+    """
+
+    @login_required
+    @functools.wraps(f)
+    def func(*args, **kwargs):
+        if not current_user.update_info():
+            logout()
+            flash("用户错误")
+            abort(403)
+
+        if not current_user.is_manager():
+            abort(403)
+        return f(*args, **kwargs)
+
+    return func
+
+
 @auth.route('/login', methods=['GET', 'POST'])
 def login():
     form = LoginForm()
@@ -55,7 +89,7 @@ def login():
 
 
 @auth.route("/logout")
-@login_required
+@web_user_required
 def logout():
     logout_user()
     flash("用户退出成功")
@@ -63,10 +97,9 @@ def logout():
 
 
 @auth.route("/about")
-@login_required
+@web_user_required
 def about():
     user: web_user.WebUser = current_user
-    user.update_info()
 
     try:
         page = int(request.args.get("page", 1))  # page 指垃圾袋的分页信息
@@ -82,17 +115,16 @@ def about():
 
 
 @auth.route("/order")
-@login_required
+@web_user_required
 def order_qr():
     """
     生成取件码和确认码
     图像临时保存在 BytesIO 中
     然后转换为Base64显示
     """
-    user: web_user.WebUser = current_user
-    user.update_info()
 
-    order, user, token = user.get_qr_code()  # 订单号, 用户ID, 确认码
+    user: web_user.WebUser = current_user
+    order, user, token = user.get_order_qr_code()  # 订单号, 用户ID, 确认码
 
     check_image = qrcode.make(data=url_for("store.check", user=user, order=order, _external=True))
     check_img_buffer = BytesIO()

+ 3 - 17
app/news/views.py

@@ -9,6 +9,7 @@ from tool.typing import Optional
 
 from app import views
 from app.web_user import WebUser
+from app.auth import views as auth_views
 
 news = Blueprint("news", __name__)
 app: Optional[Flask] = None
@@ -30,7 +31,7 @@ class NewDelete(FlaskForm):
 
 
 @news.route('/', methods=['GET', 'POST'])
-@login_required
+@auth_views.web_user_required
 def index():
     """
     Get请求时: 显示(获取)新闻消息
@@ -56,23 +57,8 @@ def index():
                            page_list=page_list, page=f"{page}", news_delete=delete_form)
 
 
-def manager_required(f):
-    """
-    检查是否有管理员权限
-    """
-
-    @functools.wraps(f)
-    def func(*args, **kwargs):
-        if not current_user.is_manager():
-            abort(403)
-        return f(*args, **kwargs)
-
-    return func
-
-
 @news.route('/delete', methods=['POST'])
-@login_required
-@manager_required
+@auth_views.manager_required
 def delete():
     """
     管理员: 删除内容

+ 11 - 32
app/store/views.py

@@ -3,14 +3,13 @@ from wtforms import TextField, SubmitField
 from flask_login import current_user
 from wtforms.validators import DataRequired
 from flask_wtf import FlaskForm
-from flask_login import login_required
-import functools
 from itsdangerous import TimedJSONWebSignatureSerializer as Serializer
 
 from conf import Config
 from tool.typing import Optional
 from app import views
-from app import web_user
+
+from app.auth import views as auth_views
 
 store = Blueprint("store", __name__)
 app: Optional[Flask] = None
@@ -35,7 +34,7 @@ class AddNewGoodsForm(FlaskForm):
 
 
 @store.route('/', methods=['GET', 'POST'])
-@login_required
+@auth_views.web_user_required
 def index():
     """
     显示购买的表单
@@ -44,13 +43,11 @@ def index():
     """
     form = BuySetForm()
     store_list = views.website.get_store_list()
-    user: web_user.WebUser = current_user
-    user.update_info()
     return render_template("store/store.html", store_list=store_list, store_form=form)
 
 
 @store.route('/buy/<int:goods_id>', methods=['POST'])
-@login_required
+@auth_views.web_user_required
 def buy(goods_id: int):
     """
     处理购买的表单
@@ -74,6 +71,8 @@ def buy(goods_id: int):
                 flash("兑换数目超出库存")
             elif res == -3:
                 flash("积分不足")
+            elif res == -5:
+                flash("用户登录冲突")
             elif res == 0:
                 flash(f"商品兑换成功, 订单: {order_id}")
             else:
@@ -82,24 +81,8 @@ def buy(goods_id: int):
     abort(404)
 
 
-def manager_required(f):
-    """
-    管理员权限
-    :return:
-    """
-
-    @functools.wraps(f)
-    def func(*args, **kwargs):
-        if not current_user.is_manager():
-            abort(403)
-        return f(*args, **kwargs)
-
-    return func
-
-
 @store.route('/set/<int:goods_id>', methods=['POST'])
-@login_required
-@manager_required
+@auth_views.manager_required
 def set_goods(goods_id: int):
     """
     设置库存
@@ -120,8 +103,7 @@ def set_goods(goods_id: int):
 
 
 @store.route('/set_score/<int:goods_id>', methods=['POST'])
-@login_required
-@manager_required
+@auth_views.manager_required
 def set_goods_score(goods_id: int):
     """
     设置兑换积分
@@ -142,8 +124,7 @@ def set_goods_score(goods_id: int):
 
 
 @store.route('/check/<string:user>/<string:order>')
-@login_required
-@manager_required
+@auth_views.manager_required
 def check(user, order):
     """
     显示取件码的获取内容
@@ -155,8 +136,7 @@ def check(user, order):
 
 
 @store.route('/confirm/<string:token>')
-@login_required
-@manager_required
+@auth_views.manager_required
 def confirm(token):
     """
     确认取件
@@ -176,8 +156,7 @@ def confirm(token):
 
 
 @store.route('/add', methods=["GET", "POST"])
-@login_required
-@manager_required
+@auth_views.manager_required
 def add_new_goods():
     """
     新增新商品

+ 5 - 3
app/web.py

@@ -6,6 +6,7 @@ from sql.store import get_store_item_list, get_store_item, confirm_order
 
 from tool.typing import *
 from tool.page import get_page
+from tool.login import create_uid
 
 from core.garbage import GarbageType
 
@@ -49,14 +50,15 @@ class AuthWebsite(WebsiteBase):
         user = find_user_by_name(name, passwd, self._db)
         if user is None:
             return None
-        return web_user.WebUser(name, uid=user.get_uid())
+        user.destruct()  # 提前释放, 后续操作与数据库无关
+        return web_user.WebUser(name, create_uid(name, passwd))
 
     def load_user_by_id(self, uid: uid_t) -> Optional["web_user.WebUser"]:
         user = find_user_by_id(uid, self._db)
         if user is None:
             return None
-        name = user.get_name()
-        return web_user.WebUser(name, uid=uid)
+        user.destruct()  # 提前释放, 后续操作与数据库无关
+        return web_user.WebUser(user.get_name(), uid)
 
     def get_user_garbage_count(self, uid: uid_t):
         return count_garbage_by_uid(uid, self._db, time_limit=False)

+ 2 - 1
app/web_goods.py

@@ -23,7 +23,7 @@ class Goods:
 
         user: User = user_.get_user()
         if user is None:
-            return -4, 0  # 系统错误
+            return -5, 0  # 系统错误
 
         try:
             score = user.get_score()
@@ -35,6 +35,7 @@ class Goods:
 
         user.add_score(-score_)
         update_user(user, views.website.db)
+        user.destruct()  # 提前释放, 后续动作与数据库无关
 
         self._quantity -= quantity
         update_goods(self._id, self._quantity, views.website.db)

+ 58 - 38
app/web_user.py

@@ -3,15 +3,19 @@ from itsdangerous import TimedJSONWebSignatureSerializer as Serializer
 
 from conf import Config
 
-from tool.login import create_uid
 from tool.typing import *
 
-from core.user import User
+from sql import DBBit
+from sql.user import search_from_user_view
+from sql.garbage import count_garbage_by_uid
+
+from core.user import User, UserType
 from . import views
 
 
 class WebAnonymous(AnonymousUserMixin):
     """ 网页匿名用户 """
+
     def __init__(self):
         self.group = "匿名用户"
         self.score = "0"
@@ -53,45 +57,65 @@ class WebAnonymous(AnonymousUserMixin):
 
 
 class WebUser(UserMixin):
-    """ 网页用户 """
-    def __init__(self, name: uname_t, passwd: passwd_t = None, uid: uid_t = None):
-        super(WebUser, self).__init__()
+    """
+    网页用户
+    代表一个网页端的登录
+    网页端只有在积分商城兑换礼品时会对数据加锁
+    """
+
+    def __init__(self, name: uname_t, uid: uid_t):
+        self._uid = uid
         self._name = name
-        if uid is None:
-            self._uid = create_uid(name, passwd)
-        else:
-            self._uid = uid
-        self.score = "0"
-        self.reputation = "0"
-        self.rubbish = "0"
-        self.group = "普通成员"
+        self._score = 0
+        self._reputation = 0
+        self._type = UserType.normal
+        self._rubbish = 0
+
+        super(WebUser, self).__init__()
         self.update_info()
 
-    def update_info(self):
-        user = views.website.get_user_by_id(self._uid)
-        if user is None:
-            return
-
-        if user.is_manager():
-            self.group = "管理员"
-            self.score = "0"
-            self.reputation = "0"
-            self.rubbish = "0"
-        else:
-            self.group = "普通成员"
-            res = user.get_info()
-            self.score = res.get('score', '0')
-            self.reputation = res.get('reputation', '0')
-            self.rubbish = res.get('rubbish', '0')
+    def update_info(self) -> bool:
+        info = search_from_user_view(columns=["Score", "Reputation", "IsManager"],
+                                     where=f"UserID='{self._uid}'",
+                                     db=views.website.db)
+        if info is None:
+            return False
+        info = info[0]
+        self._score = int(info[0])
+        self._reputation = int(info[1])
+        self._type = UserType.manager if info[2] == DBBit.BIT_1 else UserType.normal
+        self._rubbish = count_garbage_by_uid(self.uid, views.website.db)
+        if self._rubbish == -1:
+            return False
+        return True
+
+    @property
+    def score(self):
+        return f"{self._score}"
+
+    @property
+    def reputation(self):
+        return f"{self._reputation}"
+
+    @property
+    def rubbish(self):
+        return f"{self._rubbish}"
+
+    @property
+    def group(self):
+        return "管理员" if self._type == UserType.manager else "普通成员"
+
+    def is_manager(self):
+        return self._type == UserType.manager
 
     @property
     def is_active(self):
-        """Flask要求的属性"""
-        return views.website.load_user_by_id(self._uid) is not None
+        """Flask要求的属性, 表示用户是否激活(可登录), HGSSystem没有封禁用户系统, 所有用户都是被激活的"""
+        return True
 
     @property
     def is_authenticated(self):
-        """Flask要求的属性"""
+        """Flask要求的属性, 表示登录的凭据是否正确, 这里检查是否能 load_user_by_id"""
         return views.website.load_user_by_id(self._uid) is not None
 
     def get_id(self):
@@ -116,10 +140,7 @@ class WebUser(UserMixin):
         assert cur.rowcount == 1
         return str(cur.fetchone()[0])
 
-    def is_manager(self):
-        return self.group == "管理员"
-
-    def get_qr_code(self):
+    def get_order_qr_code(self):
         s = Serializer(Config.passwd_salt, expires_in=3600)  # 1h有效
         token = s.dumps({"order": f"{self.order}", "uid": f"{self._uid}"})
         return self.order, self._uid, token
@@ -147,8 +168,7 @@ class WebUser(UserMixin):
         return views.website.get_user_garbage_list(self._uid, limit=limit, offset=offset)
 
     def get_user(self) -> User:
-        res = views.website.get_user_by_id(self._uid)
-        return res
+        return views.website.get_user_by_id(self._uid)
 
     def write_news(self, text: str):
         return views.website.write_news(text, self._uid)

+ 45 - 5
core/user.py

@@ -18,11 +18,23 @@ class UserType:
 
 
 class User(metaclass=abc.ABCMeta):
-    def __init__(self, name: uname_t, uid: uid_t, user_type: enum):
+    def __init__(self, name: uname_t, uid: uid_t, user_type: enum, destruct_call: Optional[Callable]):
         self._name: uname_t = uname_t(name)
         self._uid: uid_t = uid_t(uid)
         self._type: enum = enum(user_type)
         self._lock = threading.RLock()  # 用户 互斥锁
+        self._destruct_call = destruct_call
+
+    def __del__(self):
+        if self._destruct_call is None:
+            return
+
+        _destruct_call = self._destruct_call
+        self._destruct_call = None
+        _destruct_call(self)
+
+    def destruct(self):
+        self.__del__()
 
     def is_manager(self):
         try:
@@ -79,6 +91,12 @@ class User(metaclass=abc.ABCMeta):
     def get_score(self):
         raise UserNotSupportError
 
+    def get_reputation(self):
+        raise UserNotSupportError
+
+    def get_rubbish(self):
+        raise UserNotSupportError
+
     def add_score(self, score: score_t) -> score_t:
         raise UserNotSupportError
 
@@ -90,8 +108,13 @@ class User(metaclass=abc.ABCMeta):
 
 
 class NormalUser(User):
-    def __init__(self, name: uname_t, uid: uid_t, reputation: score_t, rubbish: count_t, score: score_t):
-        super(NormalUser, self).__init__(name, uid, UserType.normal)
+    def __init__(self, name: uname_t,
+                 uid: uid_t,
+                 reputation: score_t,
+                 rubbish: count_t,
+                 score: score_t,
+                 destruct_call: Optional[Callable]):
+        super(NormalUser, self).__init__(name, uid, UserType.normal, destruct_call)
         self._reputation = score_t(reputation)
         self._rubbish = count_t(rubbish)
         self._score = score_t(score)
@@ -190,6 +213,12 @@ class NormalUser(User):
 
         return reputation
 
+    def get_reputation(self):
+        raise self._reputation
+
+    def get_rubbish(self):
+        raise self._reputation
+
     def get_score(self):
         return self._score
 
@@ -239,8 +268,10 @@ class NormalUser(User):
 
 
 class ManagerUser(User):
-    def __init__(self, name: uname_t, uid: uid_t):
-        super(ManagerUser, self).__init__(name, uid, UserType.manager)
+    def __init__(self, name: uname_t,
+                 uid: uid_t,
+                 destruct_call: Optional[Callable]):
+        super(ManagerUser, self).__init__(name, uid, UserType.manager, destruct_call)
 
     def check_rubbish(self, garbage: GarbageBag, result: bool, user: User) -> bool:
         """
@@ -283,3 +314,12 @@ class ManagerUser(User):
         finally:
             self._lock.release()
         return info
+
+    def get_reputation(self):
+        raise 0
+
+    def get_rubbish(self):
+        raise 0
+
+    def get_score(self):
+        return 0

+ 1 - 7
equipment/scan_user.py

@@ -14,19 +14,13 @@ qr_user_pattern = re.compile(r'HGSSystem-QR-USER:([a-z0-9]{32})-END', re.I)
 def scan_uid(code: QRCode) -> uid_t:
     data = code.get_data()
     res = re.match(qr_user_pattern, data)
+    print(data, res)
     if res is None:
         return ""
     else:
         return res.group(1)
 
 
-def scan_user(code: QRCode, db: DB) -> Optional[User]:
-    uid = scan_uid(code)
-    if len(uid) == 0:
-        return None
-    return find_user_by_id(uid, db)
-
-
 def __get_uid_qr_file_name(uid: uid_t, name: str, path: str, name_type="nu"):
     if name_type == "nu":
         path = os.path.join(path, f"{name}-f{uid}.png")

+ 2 - 1
init.sql

@@ -11,7 +11,8 @@ CREATE TABLE IF NOT EXISTS user -- 创建用户表
     Phone      CHAR(11)    NOT NULL CHECK (Phone REGEXP '[0-9]{11}'),
     Score      INT         NOT NULL CHECK (Score <= 500 and Score >= 0),
     Reputation INT         NOT NULL CHECK (Reputation <= 1000 and Reputation >= 1),
-    CreateTime DATETIME    NOT NULL DEFAULT CURRENT_TIMESTAMP
+    CreateTime DATETIME    NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    UserLock   BIT         NOT NULL DEFAULT 0
 );
 
 CREATE TABLE IF NOT EXISTS garbage -- 创建普通垃圾表

+ 9 - 2
sql/mysql_db.py

@@ -51,7 +51,8 @@ class MysqlDB(HGSDatabase):
                limit: Optional[int] = None,
                offset: Optional[int] = None,
                order_by: Optional[List[Tuple[str, str]]] = None,
-               group_by: Optional[List[str]] = None):
+               group_by: Optional[List[str]] = None,
+               for_update: bool = False):
         if type(where) is list and len(where) > 0:
             where: str = " WHERE " + " AND ".join(f"({w})" for w in where)
         elif type(where) is str and len(where) > 0:
@@ -81,7 +82,13 @@ class MysqlDB(HGSDatabase):
             group_by = "GROUP BY " + ", ".join(group_by)
 
         columns: str = ", ".join(columns)
-        return self.__search(f"SELECT {columns} FROM {table} {where} {group_by} {order_by} {limit} {offset};")
+        if for_update:
+            for_update = "FOR UPDATE"
+        else:
+            for_update = ""
+        return self.__search(f"SELECT {columns} "
+                             f"FROM {table} "
+                             f"{where} {group_by} {order_by} {limit} {offset} {for_update};")
 
     def insert(self, table: str, columns: list, values: Union[str, List[str]], not_commit: bool = False):
         columns: str = ", ".join(columns)

+ 31 - 8
sql/user.py

@@ -80,26 +80,43 @@ def search_from_user_view(columns, where: str, db: DB):
 
 
 def find_user_by_id(uid: uid_t, db: DB) -> Optional[User]:
-    cur = db.search(columns=["UserID", "Name", "IsManager", "Score", "Reputation"],
+    cur = db.search(columns=["UserID", "Name", "IsManager", "Score", "Reputation", "UserLock"],
                     table="user",
                     where=f"UserID = '{uid}'")
     if cur is None or cur.rowcount == 0:
         return None
     assert cur.rowcount == 1
     res = cur.fetchone()
-    assert len(res) == 5
+    assert len(res) == 6
 
     uid: uid_t = res[0]
     name: uname_t = str(res[1])
     manager: bool = res[2] == DBBit.BIT_1
+    lock: bool = res[5] == DBBit.BIT_1
+
+    if lock:
+        db.commit()
+        return None
+    else:
+        cur = db.update(table="user",
+                        kw={"UserLock": "1"},
+                        where=f"UserID = '{uid}'")
+        if cur is None or cur.rowcount == 0:
+            db.commit()
+            return None
+
+    def user_destruct(*args, **kwargs):
+        db.update(table="user",
+                  kw={"UserLock": "0"},
+                  where=f"UserID = '{uid}'")
 
     if manager:
-        return ManagerUser(name, uid)
+        return ManagerUser(name, uid, user_destruct)
     else:
         score: score_t = res[3]
         reputation: score_t = res[4]
         rubbish: count_t = garbage.count_garbage_by_uid(uid, db)
-        return NormalUser(name, uid, reputation, rubbish, score)  # rubbish 实际计算
+        return NormalUser(name, uid, reputation, rubbish, score, user_destruct)  # rubbish 实际计算
 
 
 def find_user_by_name(name: uname_t, passwd: passwd_t, db: DB) -> Optional[User]:
@@ -155,14 +172,20 @@ def create_new_user(name: Optional[uname_t], passwd: Optional[passwd_t], phone:
         return None
     is_manager = '1' if manager else '0'
     cur = db.insert(table="user",
-                    columns=["UserID", "Name", "IsManager", "Phone", "Score", "Reputation", "CreateTime"],
+                    columns=["UserID", "Name", "IsManager", "Phone", "Score", "Reputation", "CreateTime", "UserLock"],
                     values=f"'{uid}', '{name}', {is_manager}, '{phone}', {Config.default_score}, "
-                           f"{Config.default_reputation}, {mysql_time()}")
+                           f"{Config.default_reputation}, {mysql_time()}, 1")
     if cur is None:
         return None
+
+    def user_destruct(*args, **kwargs):
+        db.update(table="user",
+                  kw={"UserLock": "0"},
+                  where=f"UserID = '{uid}'")
+
     if is_manager:
-        return ManagerUser(name, uid)
-    return NormalUser(name, uid, Config.default_reputation, 0, Config.default_score)
+        return ManagerUser(name, uid, user_destruct)
+    return NormalUser(name, uid, Config.default_reputation, 0, Config.default_score, user_destruct)
 
 
 def get_user_phone(uid: uid_t, db: DB) -> Optional[str]:

+ 2 - 0
tk_ui/admin.py

@@ -639,6 +639,8 @@ class AdminStation(AdminStationBase):
             self._login_passwd[2].set('')
 
     def logout(self):
+        if self._admin is not None:
+            self._admin.destruct()
         super(AdminStation, self).logout()
         self.__show_login_window()
 

+ 15 - 7
tk_ui/station.py

@@ -92,6 +92,7 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
             return False
         if not self._user.is_manager() and time.time() - self._user_last_time > 20:
             self.show_msg("退出登录", "用户自动退出", show_time=3.0)
+            self._user.destruct()
             self._user = None
             raise Exception
         return True
@@ -123,17 +124,23 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
     def get_cap_img(self):
         return self._cap.get_frame()
 
+    def logout_user(self):
+        self._user.destruct()
+        self._user = None  # 退出登录
+        self._user_last_time = 0
+        self.show_msg("退出登录", "退出登录成功", show_time=3.0)
+        return False
+
     def switch_user(self, user: User) -> bool:
         """
         切换用户: 退出/登录
         :param user: 新用户
         :return: 登录-True, 退出-False
         """
-        if self._user is not None and self._user.get_uid() == user.get_uid() and self.check_user():  # 正在登陆期
-            self._user = None  # 退出登录
-            self._user_last_time = 0
-            self.show_msg("退出登录", "退出登录成功", show_time=3.0)
-            return False
+        if self._user is not None:
+            print(f"{self._user.get_uid()=} {user.get_uid()=}")
+        if self._user != user and self._user is not None:
+            self._user.destruct()
         self._user = user
         self._user_last_time = time.time()
         self.show_msg("登录", "登录成功", show_time=3.0)
@@ -485,9 +492,9 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
     def mainloop(self):
         ...
 
-    @abc.abstractmethod
     def exit_win(self):
-        ...
+        if self._user is not None:
+            self._user.destruct()
 
 
 from . import station_event as tk_event
@@ -1337,4 +1344,5 @@ class GarbageStation(GarbageStationBase):
         self._window.mainloop()
 
     def exit_win(self):
+        super(GarbageStation, self).exit_win()
         self._window.destroy()

+ 26 - 10
tk_ui/station_event.py

@@ -1,7 +1,7 @@
 import tempfile
 
 from equipment.scan import QRCode
-from equipment.scan_user import scan_user
+from equipment.scan_user import scan_uid
 from equipment.scan_garbage import scan_garbage
 
 from tool.typing import *
@@ -10,6 +10,7 @@ from core.user import User
 from core.garbage import GarbageBag
 
 from sql.db import DB
+from sql.user import find_user_by_id
 
 from .event import TkThreading, TkEventBase
 from . import station as tk_station
@@ -32,9 +33,21 @@ class ScanUserEvent(StationEventBase):
     若QR-CODE不是User码则调用ScanGarbage任务
     """
 
-    @staticmethod
-    def func(qr: QRCode, db: DB):
-        return scan_user(qr, db)
+    def func(self, qr: QRCode):
+        """
+        扫描用户
+        :param qr: 二维码
+        :return:
+            若是已登录用户再次扫码, 则返回 False, None
+            若是新登录用户扫码, 则返回 True, User
+            错误返回 None, None
+        """
+        uid = scan_uid(qr)
+        if len(uid) == 0:
+            return None, None
+        if uid == self.station.get_uid_no_update():
+            return False, None
+        return True, find_user_by_id(uid, self._db)
 
     def __init__(self, gb_station):
         super().__init__(gb_station, "扫码用户")
@@ -45,20 +58,23 @@ class ScanUserEvent(StationEventBase):
 
     def start(self, qr_code: QRCode):
         self._qr_code = qr_code
-        self.thread = TkThreading(self.func, qr_code, self._db)
+        self.thread = TkThreading(self.func, qr_code)
         return self
 
     def is_end(self) -> bool:
         return self.thread is not None and not self.thread.is_alive()
 
     def done_after_event(self):
-        self.thread.join()
-        if self.thread.result is not None:
-            self.station.switch_user(self.thread.result)
-            self.station.update_control()
-        else:
+        res, user = self.thread.wait_event()
+        if res is None:
             event = ScanGarbageEvent(self.station).start(self._qr_code)
             self.station.push_event(event)
+        if res:
+            self.station.switch_user(user)
+            self.station.update_control()
+        else:
+            self.station.logout_user()
+            self.station.update_control()
 
 
 class ScanGarbageEvent(StationEventBase):