1
0
SongZihuan 3 жил өмнө
parent
commit
6a7d96f78f

+ 40 - 8
app/auth/views.py

@@ -1,5 +1,5 @@
 import math
 import math
-
+import functools
 from flask import render_template, Blueprint, Flask, request, url_for, redirect, flash, abort
 from flask import render_template, Blueprint, Flask, request, url_for, redirect, flash, abort
 from wtforms import StringField, PasswordField, SubmitField
 from wtforms import StringField, PasswordField, SubmitField
 from wtforms.validators import DataRequired
 from wtforms.validators import DataRequired
@@ -34,6 +34,40 @@ class LoginForm(FlaskForm):
     submit = SubmitField("登录")
     submit = SubmitField("登录")
 
 
 
 
+def web_user_required(f):
+    @login_required
+    @functools.wraps(f)
+    def func(*args, **kwargs):
+        if not current_user.update_info():
+            logout()
+            flash("用户错误")
+            abort(403)
+        return f(*args, **kwargs)
+
+    return func
+
+
+def manager_required(f):
+    """
+    管理员权限
+    :return:
+    """
+
+    @login_required
+    @functools.wraps(f)
+    def func(*args, **kwargs):
+        if not current_user.update_info():
+            logout()
+            flash("用户错误")
+            abort(403)
+
+        if not current_user.is_manager():
+            abort(403)
+        return f(*args, **kwargs)
+
+    return func
+
+
 @auth.route('/login', methods=['GET', 'POST'])
 @auth.route('/login', methods=['GET', 'POST'])
 def login():
 def login():
     form = LoginForm()
     form = LoginForm()
@@ -55,7 +89,7 @@ def login():
 
 
 
 
 @auth.route("/logout")
 @auth.route("/logout")
-@login_required
+@web_user_required
 def logout():
 def logout():
     logout_user()
     logout_user()
     flash("用户退出成功")
     flash("用户退出成功")
@@ -63,10 +97,9 @@ def logout():
 
 
 
 
 @auth.route("/about")
 @auth.route("/about")
-@login_required
+@web_user_required
 def about():
 def about():
     user: web_user.WebUser = current_user
     user: web_user.WebUser = current_user
-    user.update_info()
 
 
     try:
     try:
         page = int(request.args.get("page", 1))  # page 指垃圾袋的分页信息
         page = int(request.args.get("page", 1))  # page 指垃圾袋的分页信息
@@ -82,17 +115,16 @@ def about():
 
 
 
 
 @auth.route("/order")
 @auth.route("/order")
-@login_required
+@web_user_required
 def order_qr():
 def order_qr():
     """
     """
     生成取件码和确认码
     生成取件码和确认码
     图像临时保存在 BytesIO 中
     图像临时保存在 BytesIO 中
     然后转换为Base64显示
     然后转换为Base64显示
     """
     """
-    user: web_user.WebUser = current_user
-    user.update_info()
 
 
-    order, user, token = user.get_qr_code()  # 订单号, 用户ID, 确认码
+    user: web_user.WebUser = current_user
+    order, user, token = user.get_order_qr_code()  # 订单号, 用户ID, 确认码
 
 
     check_image = qrcode.make(data=url_for("store.check", user=user, order=order, _external=True))
     check_image = qrcode.make(data=url_for("store.check", user=user, order=order, _external=True))
     check_img_buffer = BytesIO()
     check_img_buffer = BytesIO()

+ 3 - 17
app/news/views.py

@@ -9,6 +9,7 @@ from tool.typing import Optional
 
 
 from app import views
 from app import views
 from app.web_user import WebUser
 from app.web_user import WebUser
+from app.auth import views as auth_views
 
 
 news = Blueprint("news", __name__)
 news = Blueprint("news", __name__)
 app: Optional[Flask] = None
 app: Optional[Flask] = None
@@ -30,7 +31,7 @@ class NewDelete(FlaskForm):
 
 
 
 
 @news.route('/', methods=['GET', 'POST'])
 @news.route('/', methods=['GET', 'POST'])
-@login_required
+@auth_views.web_user_required
 def index():
 def index():
     """
     """
     Get请求时: 显示(获取)新闻消息
     Get请求时: 显示(获取)新闻消息
@@ -56,23 +57,8 @@ def index():
                            page_list=page_list, page=f"{page}", news_delete=delete_form)
                            page_list=page_list, page=f"{page}", news_delete=delete_form)
 
 
 
 
-def manager_required(f):
-    """
-    检查是否有管理员权限
-    """
-
-    @functools.wraps(f)
-    def func(*args, **kwargs):
-        if not current_user.is_manager():
-            abort(403)
-        return f(*args, **kwargs)
-
-    return func
-
-
 @news.route('/delete', methods=['POST'])
 @news.route('/delete', methods=['POST'])
-@login_required
-@manager_required
+@auth_views.manager_required
 def delete():
 def delete():
     """
     """
     管理员: 删除内容
     管理员: 删除内容

+ 11 - 32
app/store/views.py

@@ -3,14 +3,13 @@ from wtforms import TextField, SubmitField
 from flask_login import current_user
 from flask_login import current_user
 from wtforms.validators import DataRequired
 from wtforms.validators import DataRequired
 from flask_wtf import FlaskForm
 from flask_wtf import FlaskForm
-from flask_login import login_required
-import functools
 from itsdangerous import TimedJSONWebSignatureSerializer as Serializer
 from itsdangerous import TimedJSONWebSignatureSerializer as Serializer
 
 
 from conf import Config
 from conf import Config
 from tool.typing import Optional
 from tool.typing import Optional
 from app import views
 from app import views
-from app import web_user
+
+from app.auth import views as auth_views
 
 
 store = Blueprint("store", __name__)
 store = Blueprint("store", __name__)
 app: Optional[Flask] = None
 app: Optional[Flask] = None
@@ -35,7 +34,7 @@ class AddNewGoodsForm(FlaskForm):
 
 
 
 
 @store.route('/', methods=['GET', 'POST'])
 @store.route('/', methods=['GET', 'POST'])
-@login_required
+@auth_views.web_user_required
 def index():
 def index():
     """
     """
     显示购买的表单
     显示购买的表单
@@ -44,13 +43,11 @@ def index():
     """
     """
     form = BuySetForm()
     form = BuySetForm()
     store_list = views.website.get_store_list()
     store_list = views.website.get_store_list()
-    user: web_user.WebUser = current_user
-    user.update_info()
     return render_template("store/store.html", store_list=store_list, store_form=form)
     return render_template("store/store.html", store_list=store_list, store_form=form)
 
 
 
 
 @store.route('/buy/<int:goods_id>', methods=['POST'])
 @store.route('/buy/<int:goods_id>', methods=['POST'])
-@login_required
+@auth_views.web_user_required
 def buy(goods_id: int):
 def buy(goods_id: int):
     """
     """
     处理购买的表单
     处理购买的表单
@@ -74,6 +71,8 @@ def buy(goods_id: int):
                 flash("兑换数目超出库存")
                 flash("兑换数目超出库存")
             elif res == -3:
             elif res == -3:
                 flash("积分不足")
                 flash("积分不足")
+            elif res == -5:
+                flash("用户登录冲突")
             elif res == 0:
             elif res == 0:
                 flash(f"商品兑换成功, 订单: {order_id}")
                 flash(f"商品兑换成功, 订单: {order_id}")
             else:
             else:
@@ -82,24 +81,8 @@ def buy(goods_id: int):
     abort(404)
     abort(404)
 
 
 
 
-def manager_required(f):
-    """
-    管理员权限
-    :return:
-    """
-
-    @functools.wraps(f)
-    def func(*args, **kwargs):
-        if not current_user.is_manager():
-            abort(403)
-        return f(*args, **kwargs)
-
-    return func
-
-
 @store.route('/set/<int:goods_id>', methods=['POST'])
 @store.route('/set/<int:goods_id>', methods=['POST'])
-@login_required
-@manager_required
+@auth_views.manager_required
 def set_goods(goods_id: int):
 def set_goods(goods_id: int):
     """
     """
     设置库存
     设置库存
@@ -120,8 +103,7 @@ def set_goods(goods_id: int):
 
 
 
 
 @store.route('/set_score/<int:goods_id>', methods=['POST'])
 @store.route('/set_score/<int:goods_id>', methods=['POST'])
-@login_required
-@manager_required
+@auth_views.manager_required
 def set_goods_score(goods_id: int):
 def set_goods_score(goods_id: int):
     """
     """
     设置兑换积分
     设置兑换积分
@@ -142,8 +124,7 @@ def set_goods_score(goods_id: int):
 
 
 
 
 @store.route('/check/<string:user>/<string:order>')
 @store.route('/check/<string:user>/<string:order>')
-@login_required
-@manager_required
+@auth_views.manager_required
 def check(user, order):
 def check(user, order):
     """
     """
     显示取件码的获取内容
     显示取件码的获取内容
@@ -155,8 +136,7 @@ def check(user, order):
 
 
 
 
 @store.route('/confirm/<string:token>')
 @store.route('/confirm/<string:token>')
-@login_required
-@manager_required
+@auth_views.manager_required
 def confirm(token):
 def confirm(token):
     """
     """
     确认取件
     确认取件
@@ -176,8 +156,7 @@ def confirm(token):
 
 
 
 
 @store.route('/add', methods=["GET", "POST"])
 @store.route('/add', methods=["GET", "POST"])
-@login_required
-@manager_required
+@auth_views.manager_required
 def add_new_goods():
 def add_new_goods():
     """
     """
     新增新商品
     新增新商品

+ 5 - 3
app/web.py

@@ -6,6 +6,7 @@ from sql.store import get_store_item_list, get_store_item, confirm_order
 
 
 from tool.typing import *
 from tool.typing import *
 from tool.page import get_page
 from tool.page import get_page
+from tool.login import create_uid
 
 
 from core.garbage import GarbageType
 from core.garbage import GarbageType
 
 
@@ -49,14 +50,15 @@ class AuthWebsite(WebsiteBase):
         user = find_user_by_name(name, passwd, self._db)
         user = find_user_by_name(name, passwd, self._db)
         if user is None:
         if user is None:
             return None
             return None
-        return web_user.WebUser(name, uid=user.get_uid())
+        user.destruct()  # 提前释放, 后续操作与数据库无关
+        return web_user.WebUser(name, create_uid(name, passwd))
 
 
     def load_user_by_id(self, uid: uid_t) -> Optional["web_user.WebUser"]:
     def load_user_by_id(self, uid: uid_t) -> Optional["web_user.WebUser"]:
         user = find_user_by_id(uid, self._db)
         user = find_user_by_id(uid, self._db)
         if user is None:
         if user is None:
             return None
             return None
-        name = user.get_name()
-        return web_user.WebUser(name, uid=uid)
+        user.destruct()  # 提前释放, 后续操作与数据库无关
+        return web_user.WebUser(user.get_name(), uid)
 
 
     def get_user_garbage_count(self, uid: uid_t):
     def get_user_garbage_count(self, uid: uid_t):
         return count_garbage_by_uid(uid, self._db, time_limit=False)
         return count_garbage_by_uid(uid, self._db, time_limit=False)

+ 2 - 1
app/web_goods.py

@@ -23,7 +23,7 @@ class Goods:
 
 
         user: User = user_.get_user()
         user: User = user_.get_user()
         if user is None:
         if user is None:
-            return -4, 0  # 系统错误
+            return -5, 0  # 系统错误
 
 
         try:
         try:
             score = user.get_score()
             score = user.get_score()
@@ -35,6 +35,7 @@ class Goods:
 
 
         user.add_score(-score_)
         user.add_score(-score_)
         update_user(user, views.website.db)
         update_user(user, views.website.db)
+        user.destruct()  # 提前释放, 后续动作与数据库无关
 
 
         self._quantity -= quantity
         self._quantity -= quantity
         update_goods(self._id, self._quantity, views.website.db)
         update_goods(self._id, self._quantity, views.website.db)

+ 58 - 38
app/web_user.py

@@ -3,15 +3,19 @@ from itsdangerous import TimedJSONWebSignatureSerializer as Serializer
 
 
 from conf import Config
 from conf import Config
 
 
-from tool.login import create_uid
 from tool.typing import *
 from tool.typing import *
 
 
-from core.user import User
+from sql import DBBit
+from sql.user import search_from_user_view
+from sql.garbage import count_garbage_by_uid
+
+from core.user import User, UserType
 from . import views
 from . import views
 
 
 
 
 class WebAnonymous(AnonymousUserMixin):
 class WebAnonymous(AnonymousUserMixin):
     """ 网页匿名用户 """
     """ 网页匿名用户 """
+
     def __init__(self):
     def __init__(self):
         self.group = "匿名用户"
         self.group = "匿名用户"
         self.score = "0"
         self.score = "0"
@@ -53,45 +57,65 @@ class WebAnonymous(AnonymousUserMixin):
 
 
 
 
 class WebUser(UserMixin):
 class WebUser(UserMixin):
-    """ 网页用户 """
-    def __init__(self, name: uname_t, passwd: passwd_t = None, uid: uid_t = None):
-        super(WebUser, self).__init__()
+    """
+    网页用户
+    代表一个网页端的登录
+    网页端只有在积分商城兑换礼品时会对数据加锁
+    """
+
+    def __init__(self, name: uname_t, uid: uid_t):
+        self._uid = uid
         self._name = name
         self._name = name
-        if uid is None:
-            self._uid = create_uid(name, passwd)
-        else:
-            self._uid = uid
-        self.score = "0"
-        self.reputation = "0"
-        self.rubbish = "0"
-        self.group = "普通成员"
+        self._score = 0
+        self._reputation = 0
+        self._type = UserType.normal
+        self._rubbish = 0
+
+        super(WebUser, self).__init__()
         self.update_info()
         self.update_info()
 
 
-    def update_info(self):
-        user = views.website.get_user_by_id(self._uid)
-        if user is None:
-            return
-
-        if user.is_manager():
-            self.group = "管理员"
-            self.score = "0"
-            self.reputation = "0"
-            self.rubbish = "0"
-        else:
-            self.group = "普通成员"
-            res = user.get_info()
-            self.score = res.get('score', '0')
-            self.reputation = res.get('reputation', '0')
-            self.rubbish = res.get('rubbish', '0')
+    def update_info(self) -> bool:
+        info = search_from_user_view(columns=["Score", "Reputation", "IsManager"],
+                                     where=f"UserID='{self._uid}'",
+                                     db=views.website.db)
+        if info is None:
+            return False
+        info = info[0]
+        self._score = int(info[0])
+        self._reputation = int(info[1])
+        self._type = UserType.manager if info[2] == DBBit.BIT_1 else UserType.normal
+        self._rubbish = count_garbage_by_uid(self.uid, views.website.db)
+        if self._rubbish == -1:
+            return False
+        return True
+
+    @property
+    def score(self):
+        return f"{self._score}"
+
+    @property
+    def reputation(self):
+        return f"{self._reputation}"
+
+    @property
+    def rubbish(self):
+        return f"{self._rubbish}"
+
+    @property
+    def group(self):
+        return "管理员" if self._type == UserType.manager else "普通成员"
+
+    def is_manager(self):
+        return self._type == UserType.manager
 
 
     @property
     @property
     def is_active(self):
     def is_active(self):
-        """Flask要求的属性"""
-        return views.website.load_user_by_id(self._uid) is not None
+        """Flask要求的属性, 表示用户是否激活(可登录), HGSSystem没有封禁用户系统, 所有用户都是被激活的"""
+        return True
 
 
     @property
     @property
     def is_authenticated(self):
     def is_authenticated(self):
-        """Flask要求的属性"""
+        """Flask要求的属性, 表示登录的凭据是否正确, 这里检查是否能 load_user_by_id"""
         return views.website.load_user_by_id(self._uid) is not None
         return views.website.load_user_by_id(self._uid) is not None
 
 
     def get_id(self):
     def get_id(self):
@@ -116,10 +140,7 @@ class WebUser(UserMixin):
         assert cur.rowcount == 1
         assert cur.rowcount == 1
         return str(cur.fetchone()[0])
         return str(cur.fetchone()[0])
 
 
-    def is_manager(self):
-        return self.group == "管理员"
-
-    def get_qr_code(self):
+    def get_order_qr_code(self):
         s = Serializer(Config.passwd_salt, expires_in=3600)  # 1h有效
         s = Serializer(Config.passwd_salt, expires_in=3600)  # 1h有效
         token = s.dumps({"order": f"{self.order}", "uid": f"{self._uid}"})
         token = s.dumps({"order": f"{self.order}", "uid": f"{self._uid}"})
         return self.order, self._uid, token
         return self.order, self._uid, token
@@ -147,8 +168,7 @@ class WebUser(UserMixin):
         return views.website.get_user_garbage_list(self._uid, limit=limit, offset=offset)
         return views.website.get_user_garbage_list(self._uid, limit=limit, offset=offset)
 
 
     def get_user(self) -> User:
     def get_user(self) -> User:
-        res = views.website.get_user_by_id(self._uid)
-        return res
+        return views.website.get_user_by_id(self._uid)
 
 
     def write_news(self, text: str):
     def write_news(self, text: str):
         return views.website.write_news(text, self._uid)
         return views.website.write_news(text, self._uid)

+ 45 - 5
core/user.py

@@ -18,11 +18,23 @@ class UserType:
 
 
 
 
 class User(metaclass=abc.ABCMeta):
 class User(metaclass=abc.ABCMeta):
-    def __init__(self, name: uname_t, uid: uid_t, user_type: enum):
+    def __init__(self, name: uname_t, uid: uid_t, user_type: enum, destruct_call: Optional[Callable]):
         self._name: uname_t = uname_t(name)
         self._name: uname_t = uname_t(name)
         self._uid: uid_t = uid_t(uid)
         self._uid: uid_t = uid_t(uid)
         self._type: enum = enum(user_type)
         self._type: enum = enum(user_type)
         self._lock = threading.RLock()  # 用户 互斥锁
         self._lock = threading.RLock()  # 用户 互斥锁
+        self._destruct_call = destruct_call
+
+    def __del__(self):
+        if self._destruct_call is None:
+            return
+
+        _destruct_call = self._destruct_call
+        self._destruct_call = None
+        _destruct_call(self)
+
+    def destruct(self):
+        self.__del__()
 
 
     def is_manager(self):
     def is_manager(self):
         try:
         try:
@@ -79,6 +91,12 @@ class User(metaclass=abc.ABCMeta):
     def get_score(self):
     def get_score(self):
         raise UserNotSupportError
         raise UserNotSupportError
 
 
+    def get_reputation(self):
+        raise UserNotSupportError
+
+    def get_rubbish(self):
+        raise UserNotSupportError
+
     def add_score(self, score: score_t) -> score_t:
     def add_score(self, score: score_t) -> score_t:
         raise UserNotSupportError
         raise UserNotSupportError
 
 
@@ -90,8 +108,13 @@ class User(metaclass=abc.ABCMeta):
 
 
 
 
 class NormalUser(User):
 class NormalUser(User):
-    def __init__(self, name: uname_t, uid: uid_t, reputation: score_t, rubbish: count_t, score: score_t):
-        super(NormalUser, self).__init__(name, uid, UserType.normal)
+    def __init__(self, name: uname_t,
+                 uid: uid_t,
+                 reputation: score_t,
+                 rubbish: count_t,
+                 score: score_t,
+                 destruct_call: Optional[Callable]):
+        super(NormalUser, self).__init__(name, uid, UserType.normal, destruct_call)
         self._reputation = score_t(reputation)
         self._reputation = score_t(reputation)
         self._rubbish = count_t(rubbish)
         self._rubbish = count_t(rubbish)
         self._score = score_t(score)
         self._score = score_t(score)
@@ -190,6 +213,12 @@ class NormalUser(User):
 
 
         return reputation
         return reputation
 
 
+    def get_reputation(self):
+        raise self._reputation
+
+    def get_rubbish(self):
+        raise self._reputation
+
     def get_score(self):
     def get_score(self):
         return self._score
         return self._score
 
 
@@ -239,8 +268,10 @@ class NormalUser(User):
 
 
 
 
 class ManagerUser(User):
 class ManagerUser(User):
-    def __init__(self, name: uname_t, uid: uid_t):
-        super(ManagerUser, self).__init__(name, uid, UserType.manager)
+    def __init__(self, name: uname_t,
+                 uid: uid_t,
+                 destruct_call: Optional[Callable]):
+        super(ManagerUser, self).__init__(name, uid, UserType.manager, destruct_call)
 
 
     def check_rubbish(self, garbage: GarbageBag, result: bool, user: User) -> bool:
     def check_rubbish(self, garbage: GarbageBag, result: bool, user: User) -> bool:
         """
         """
@@ -283,3 +314,12 @@ class ManagerUser(User):
         finally:
         finally:
             self._lock.release()
             self._lock.release()
         return info
         return info
+
+    def get_reputation(self):
+        raise 0
+
+    def get_rubbish(self):
+        raise 0
+
+    def get_score(self):
+        return 0

+ 1 - 7
equipment/scan_user.py

@@ -14,19 +14,13 @@ qr_user_pattern = re.compile(r'HGSSystem-QR-USER:([a-z0-9]{32})-END', re.I)
 def scan_uid(code: QRCode) -> uid_t:
 def scan_uid(code: QRCode) -> uid_t:
     data = code.get_data()
     data = code.get_data()
     res = re.match(qr_user_pattern, data)
     res = re.match(qr_user_pattern, data)
+    print(data, res)
     if res is None:
     if res is None:
         return ""
         return ""
     else:
     else:
         return res.group(1)
         return res.group(1)
 
 
 
 
-def scan_user(code: QRCode, db: DB) -> Optional[User]:
-    uid = scan_uid(code)
-    if len(uid) == 0:
-        return None
-    return find_user_by_id(uid, db)
-
-
 def __get_uid_qr_file_name(uid: uid_t, name: str, path: str, name_type="nu"):
 def __get_uid_qr_file_name(uid: uid_t, name: str, path: str, name_type="nu"):
     if name_type == "nu":
     if name_type == "nu":
         path = os.path.join(path, f"{name}-f{uid}.png")
         path = os.path.join(path, f"{name}-f{uid}.png")

+ 2 - 1
init.sql

@@ -11,7 +11,8 @@ CREATE TABLE IF NOT EXISTS user -- 创建用户表
     Phone      CHAR(11)    NOT NULL CHECK (Phone REGEXP '[0-9]{11}'),
     Phone      CHAR(11)    NOT NULL CHECK (Phone REGEXP '[0-9]{11}'),
     Score      INT         NOT NULL CHECK (Score <= 500 and Score >= 0),
     Score      INT         NOT NULL CHECK (Score <= 500 and Score >= 0),
     Reputation INT         NOT NULL CHECK (Reputation <= 1000 and Reputation >= 1),
     Reputation INT         NOT NULL CHECK (Reputation <= 1000 and Reputation >= 1),
-    CreateTime DATETIME    NOT NULL DEFAULT CURRENT_TIMESTAMP
+    CreateTime DATETIME    NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    UserLock   BIT         NOT NULL DEFAULT 0
 );
 );
 
 
 CREATE TABLE IF NOT EXISTS garbage -- 创建普通垃圾表
 CREATE TABLE IF NOT EXISTS garbage -- 创建普通垃圾表

+ 9 - 2
sql/mysql_db.py

@@ -51,7 +51,8 @@ class MysqlDB(HGSDatabase):
                limit: Optional[int] = None,
                limit: Optional[int] = None,
                offset: Optional[int] = None,
                offset: Optional[int] = None,
                order_by: Optional[List[Tuple[str, str]]] = None,
                order_by: Optional[List[Tuple[str, str]]] = None,
-               group_by: Optional[List[str]] = None):
+               group_by: Optional[List[str]] = None,
+               for_update: bool = False):
         if type(where) is list and len(where) > 0:
         if type(where) is list and len(where) > 0:
             where: str = " WHERE " + " AND ".join(f"({w})" for w in where)
             where: str = " WHERE " + " AND ".join(f"({w})" for w in where)
         elif type(where) is str and len(where) > 0:
         elif type(where) is str and len(where) > 0:
@@ -81,7 +82,13 @@ class MysqlDB(HGSDatabase):
             group_by = "GROUP BY " + ", ".join(group_by)
             group_by = "GROUP BY " + ", ".join(group_by)
 
 
         columns: str = ", ".join(columns)
         columns: str = ", ".join(columns)
-        return self.__search(f"SELECT {columns} FROM {table} {where} {group_by} {order_by} {limit} {offset};")
+        if for_update:
+            for_update = "FOR UPDATE"
+        else:
+            for_update = ""
+        return self.__search(f"SELECT {columns} "
+                             f"FROM {table} "
+                             f"{where} {group_by} {order_by} {limit} {offset} {for_update};")
 
 
     def insert(self, table: str, columns: list, values: Union[str, List[str]], not_commit: bool = False):
     def insert(self, table: str, columns: list, values: Union[str, List[str]], not_commit: bool = False):
         columns: str = ", ".join(columns)
         columns: str = ", ".join(columns)

+ 31 - 8
sql/user.py

@@ -80,26 +80,43 @@ def search_from_user_view(columns, where: str, db: DB):
 
 
 
 
 def find_user_by_id(uid: uid_t, db: DB) -> Optional[User]:
 def find_user_by_id(uid: uid_t, db: DB) -> Optional[User]:
-    cur = db.search(columns=["UserID", "Name", "IsManager", "Score", "Reputation"],
+    cur = db.search(columns=["UserID", "Name", "IsManager", "Score", "Reputation", "UserLock"],
                     table="user",
                     table="user",
                     where=f"UserID = '{uid}'")
                     where=f"UserID = '{uid}'")
     if cur is None or cur.rowcount == 0:
     if cur is None or cur.rowcount == 0:
         return None
         return None
     assert cur.rowcount == 1
     assert cur.rowcount == 1
     res = cur.fetchone()
     res = cur.fetchone()
-    assert len(res) == 5
+    assert len(res) == 6
 
 
     uid: uid_t = res[0]
     uid: uid_t = res[0]
     name: uname_t = str(res[1])
     name: uname_t = str(res[1])
     manager: bool = res[2] == DBBit.BIT_1
     manager: bool = res[2] == DBBit.BIT_1
+    lock: bool = res[5] == DBBit.BIT_1
+
+    if lock:
+        db.commit()
+        return None
+    else:
+        cur = db.update(table="user",
+                        kw={"UserLock": "1"},
+                        where=f"UserID = '{uid}'")
+        if cur is None or cur.rowcount == 0:
+            db.commit()
+            return None
+
+    def user_destruct(*args, **kwargs):
+        db.update(table="user",
+                  kw={"UserLock": "0"},
+                  where=f"UserID = '{uid}'")
 
 
     if manager:
     if manager:
-        return ManagerUser(name, uid)
+        return ManagerUser(name, uid, user_destruct)
     else:
     else:
         score: score_t = res[3]
         score: score_t = res[3]
         reputation: score_t = res[4]
         reputation: score_t = res[4]
         rubbish: count_t = garbage.count_garbage_by_uid(uid, db)
         rubbish: count_t = garbage.count_garbage_by_uid(uid, db)
-        return NormalUser(name, uid, reputation, rubbish, score)  # rubbish 实际计算
+        return NormalUser(name, uid, reputation, rubbish, score, user_destruct)  # rubbish 实际计算
 
 
 
 
 def find_user_by_name(name: uname_t, passwd: passwd_t, db: DB) -> Optional[User]:
 def find_user_by_name(name: uname_t, passwd: passwd_t, db: DB) -> Optional[User]:
@@ -155,14 +172,20 @@ def create_new_user(name: Optional[uname_t], passwd: Optional[passwd_t], phone:
         return None
         return None
     is_manager = '1' if manager else '0'
     is_manager = '1' if manager else '0'
     cur = db.insert(table="user",
     cur = db.insert(table="user",
-                    columns=["UserID", "Name", "IsManager", "Phone", "Score", "Reputation", "CreateTime"],
+                    columns=["UserID", "Name", "IsManager", "Phone", "Score", "Reputation", "CreateTime", "UserLock"],
                     values=f"'{uid}', '{name}', {is_manager}, '{phone}', {Config.default_score}, "
                     values=f"'{uid}', '{name}', {is_manager}, '{phone}', {Config.default_score}, "
-                           f"{Config.default_reputation}, {mysql_time()}")
+                           f"{Config.default_reputation}, {mysql_time()}, 1")
     if cur is None:
     if cur is None:
         return None
         return None
+
+    def user_destruct(*args, **kwargs):
+        db.update(table="user",
+                  kw={"UserLock": "0"},
+                  where=f"UserID = '{uid}'")
+
     if is_manager:
     if is_manager:
-        return ManagerUser(name, uid)
-    return NormalUser(name, uid, Config.default_reputation, 0, Config.default_score)
+        return ManagerUser(name, uid, user_destruct)
+    return NormalUser(name, uid, Config.default_reputation, 0, Config.default_score, user_destruct)
 
 
 
 
 def get_user_phone(uid: uid_t, db: DB) -> Optional[str]:
 def get_user_phone(uid: uid_t, db: DB) -> Optional[str]:

+ 2 - 0
tk_ui/admin.py

@@ -639,6 +639,8 @@ class AdminStation(AdminStationBase):
             self._login_passwd[2].set('')
             self._login_passwd[2].set('')
 
 
     def logout(self):
     def logout(self):
+        if self._admin is not None:
+            self._admin.destruct()
         super(AdminStation, self).logout()
         super(AdminStation, self).logout()
         self.__show_login_window()
         self.__show_login_window()
 
 

+ 15 - 7
tk_ui/station.py

@@ -92,6 +92,7 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
             return False
             return False
         if not self._user.is_manager() and time.time() - self._user_last_time > 20:
         if not self._user.is_manager() and time.time() - self._user_last_time > 20:
             self.show_msg("退出登录", "用户自动退出", show_time=3.0)
             self.show_msg("退出登录", "用户自动退出", show_time=3.0)
+            self._user.destruct()
             self._user = None
             self._user = None
             raise Exception
             raise Exception
         return True
         return True
@@ -123,17 +124,23 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
     def get_cap_img(self):
     def get_cap_img(self):
         return self._cap.get_frame()
         return self._cap.get_frame()
 
 
+    def logout_user(self):
+        self._user.destruct()
+        self._user = None  # 退出登录
+        self._user_last_time = 0
+        self.show_msg("退出登录", "退出登录成功", show_time=3.0)
+        return False
+
     def switch_user(self, user: User) -> bool:
     def switch_user(self, user: User) -> bool:
         """
         """
         切换用户: 退出/登录
         切换用户: 退出/登录
         :param user: 新用户
         :param user: 新用户
         :return: 登录-True, 退出-False
         :return: 登录-True, 退出-False
         """
         """
-        if self._user is not None and self._user.get_uid() == user.get_uid() and self.check_user():  # 正在登陆期
-            self._user = None  # 退出登录
-            self._user_last_time = 0
-            self.show_msg("退出登录", "退出登录成功", show_time=3.0)
-            return False
+        if self._user is not None:
+            print(f"{self._user.get_uid()=} {user.get_uid()=}")
+        if self._user != user and self._user is not None:
+            self._user.destruct()
         self._user = user
         self._user = user
         self._user_last_time = time.time()
         self._user_last_time = time.time()
         self.show_msg("登录", "登录成功", show_time=3.0)
         self.show_msg("登录", "登录成功", show_time=3.0)
@@ -485,9 +492,9 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
     def mainloop(self):
     def mainloop(self):
         ...
         ...
 
 
-    @abc.abstractmethod
     def exit_win(self):
     def exit_win(self):
-        ...
+        if self._user is not None:
+            self._user.destruct()
 
 
 
 
 from . import station_event as tk_event
 from . import station_event as tk_event
@@ -1337,4 +1344,5 @@ class GarbageStation(GarbageStationBase):
         self._window.mainloop()
         self._window.mainloop()
 
 
     def exit_win(self):
     def exit_win(self):
+        super(GarbageStation, self).exit_win()
         self._window.destroy()
         self._window.destroy()

+ 26 - 10
tk_ui/station_event.py

@@ -1,7 +1,7 @@
 import tempfile
 import tempfile
 
 
 from equipment.scan import QRCode
 from equipment.scan import QRCode
-from equipment.scan_user import scan_user
+from equipment.scan_user import scan_uid
 from equipment.scan_garbage import scan_garbage
 from equipment.scan_garbage import scan_garbage
 
 
 from tool.typing import *
 from tool.typing import *
@@ -10,6 +10,7 @@ from core.user import User
 from core.garbage import GarbageBag
 from core.garbage import GarbageBag
 
 
 from sql.db import DB
 from sql.db import DB
+from sql.user import find_user_by_id
 
 
 from .event import TkThreading, TkEventBase
 from .event import TkThreading, TkEventBase
 from . import station as tk_station
 from . import station as tk_station
@@ -32,9 +33,21 @@ class ScanUserEvent(StationEventBase):
     若QR-CODE不是User码则调用ScanGarbage任务
     若QR-CODE不是User码则调用ScanGarbage任务
     """
     """
 
 
-    @staticmethod
-    def func(qr: QRCode, db: DB):
-        return scan_user(qr, db)
+    def func(self, qr: QRCode):
+        """
+        扫描用户
+        :param qr: 二维码
+        :return:
+            若是已登录用户再次扫码, 则返回 False, None
+            若是新登录用户扫码, 则返回 True, User
+            错误返回 None, None
+        """
+        uid = scan_uid(qr)
+        if len(uid) == 0:
+            return None, None
+        if uid == self.station.get_uid_no_update():
+            return False, None
+        return True, find_user_by_id(uid, self._db)
 
 
     def __init__(self, gb_station):
     def __init__(self, gb_station):
         super().__init__(gb_station, "扫码用户")
         super().__init__(gb_station, "扫码用户")
@@ -45,20 +58,23 @@ class ScanUserEvent(StationEventBase):
 
 
     def start(self, qr_code: QRCode):
     def start(self, qr_code: QRCode):
         self._qr_code = qr_code
         self._qr_code = qr_code
-        self.thread = TkThreading(self.func, qr_code, self._db)
+        self.thread = TkThreading(self.func, qr_code)
         return self
         return self
 
 
     def is_end(self) -> bool:
     def is_end(self) -> bool:
         return self.thread is not None and not self.thread.is_alive()
         return self.thread is not None and not self.thread.is_alive()
 
 
     def done_after_event(self):
     def done_after_event(self):
-        self.thread.join()
-        if self.thread.result is not None:
-            self.station.switch_user(self.thread.result)
-            self.station.update_control()
-        else:
+        res, user = self.thread.wait_event()
+        if res is None:
             event = ScanGarbageEvent(self.station).start(self._qr_code)
             event = ScanGarbageEvent(self.station).start(self._qr_code)
             self.station.push_event(event)
             self.station.push_event(event)
+        if res:
+            self.station.switch_user(user)
+            self.station.update_control()
+        else:
+            self.station.logout_user()
+            self.station.update_control()
 
 
 
 
 class ScanGarbageEvent(StationEventBase):
 class ScanGarbageEvent(StationEventBase):