Pārlūkot izejas kodu

feat: 优化初始化程序

SongZihuan 3 gadi atpakaļ
vecāks
revīzija
b111759856
2 mainītis faili ar 73 papildinājumiem un 61 dzēšanām
  1. 10 10
      main.py
  2. 63 51
      setup.py

+ 10 - 10
main.py

@@ -12,10 +12,14 @@ from conf import Config
 
 
 def can_not_load(name):
-    print(f"无法加载 {name} 模块, 该系统可能不存在", file=sys.stderr)
+    print(f"无法加载 {name} 系统, 该系统或其依赖可能不存在", file=sys.stderr)
 
 
 def main():
+    if Config.mysql_url is None or Config.mysql_name is None:
+        print("请提供MySQL信息")
+        sys.exit(1)
+
     program_name = Config.program
     if program_name == "setup":  # setup程序不需要数据库链接等操作
         __main = os.path.dirname(os.path.abspath(__file__))
@@ -25,22 +29,18 @@ def main():
                         f"--mysql_passwd={Config.mysql_passwd} "
                         f"--program=setup")
         if res != 0:
-            print("初始化程序加载失败", file=sys.stderr)
-            exit(1)
-        exit(0)
+            print("初始化程序加载失败, 请检查配置是否正确而", file=sys.stderr)
+            sys.exit(1)
+        sys.exit(0)
 
     from sql.db import DB
-
-    if Config.mysql_url is None or Config.mysql_name is None:
-        print("MySQL 错误")
-        exit(1)
     mysql = DB()
 
     if program_name == "garbage":
         from equipment.aliyun import Aliyun
         if Config.aliyun_key is None or Config.aliyun_secret is None:
-            print("Aliyun key 错误")
-            exit(1)
+            print("请提供Aliyun key信息")
+            sys.exit(1)
 
         try:
             from equipment.scan import HGSCapture, HGSQRCoder

+ 63 - 51
setup.py

@@ -3,17 +3,18 @@ import sys
 import time
 from typing import Union, List
 
+print("初始化程序开始执行")
 print("开始检查依赖")
 
 try:
     __import__("pip")
 except ImportError:
     print("检查结束, 未找到pip")
-    exit(1)
+    sys.exit(1)
 else:
     print("依赖 pip 存在")
     if os.system(f"{sys.executable} -m pip install --upgrade pip") != 0:
-        print(f"依赖 pip 更新失败")
+        print(f"依赖 pip 更新失败", file=sys.stderr)
     else:
         print(f"依赖 pip 更新成功")
 
@@ -31,9 +32,11 @@ def check_import(packages: Union[str, List[str]], pips: Union[str, List[str]]):
             print(f"依赖 {package} 存在")
     except ImportError:
         for pip in pips:
-            if os.system(f"{sys.executable} -m pip install {pip}") != 0:
-                print(f"{pip} 依赖安装失败")
-                exit(1)
+            command = f"{sys.executable} -m pip install {pip}"
+            print(f"依赖 {pip} 安装: {command}")
+            if os.system(command) != 0:
+                print(f"依赖 {pip} 安装失败", file=sys.stderr)
+                sys.exit(1)
             else:
                 print(f"依赖 {packages}:{pip} 安装成功")
 
@@ -50,11 +53,6 @@ check_import("matplotlib", "matplotlib")  # matplotlib依赖
 check_import(["oss2", "viapi", "aliyunsdkcore", "aliyunsdkimagerecog"],
              ["oss2", "aliyun-python-sdk-viapiutils", "viapi-utils", "aliyun-python-sdk-imagerecog"])  # 阿里云依赖
 
-print("是否执行数据库初始化程序?\n执行初始化程序会令你丢失所有数据.")
-res = input("[Y/n]")
-if res != 'Y':
-    exit(0)
-
 import pymysql
 from conf import Config
 
@@ -62,29 +60,38 @@ mysql_url = Config.mysql_url
 mysql_name = Config.mysql_name
 mysql_passwd = Config.mysql_passwd
 
-sql = pymysql.connect(user=mysql_name, password=mysql_passwd, host=mysql_url)
-cursor = sql.cursor()
-with open(os.path.join(__setup, "setup.sql"), "r", encoding='utf-8') as f:
-    all_sql = f.read().split(';')
-    for s in all_sql:
-        if s.strip() == "":
-            continue
-        cursor.execute(f"{s};")
-    sql.commit()
-
-admin_passwd = input("Enter Admin Passwd: ")
-admin_phone = ""
-while len(admin_phone) != 11:
-    admin_phone = input("Enter Admin Phone[len = 11]: ")
+try:
+    sql = pymysql.connect(user=mysql_name, password=mysql_passwd, host=mysql_url)
+    cursor = sql.cursor()
+except pymysql.err:
+    print("请提供正确的MySQL信息", file=sys.stderr)
+    sys.exit(1)
 
-from tool.login import create_uid
-from tool.time_ import mysql_time
+print("是否执行数据库初始化程序?\n执行初始化程序会令你丢失所有数据.")
+res = input("[Y/n]")
+if res == 'Y' or res == 'y':
+    with open(os.path.join(__setup, "setup.sql"), "r", encoding='utf-8') as f:
+        all_sql = f.read().split(';')
+        for s in all_sql:
+            if s.strip() == "":
+                continue
+            cursor.execute(f"{s};")
+        sql.commit()
+
+    admin_passwd = input("创建 'admin' 管理员的密码: ")
+    admin_phone = ""
+    while len(admin_phone) != 11:
+        admin_phone = input("输入 'admin' 管理员的电话[长度=11]: ")
+
+    from tool.login import create_uid
+    from tool.time_ import mysql_time
+
+    # 生成基本 admin 用户
+    uid = create_uid("admin", admin_passwd)
+    cursor.execute(f"INSERT INTO user(UserID, Name, IsManager, Phone, Score, Reputation, CreateTime) "
+                   f"VALUES ('{uid}', 'admin', 1, '{admin_phone}', 10, 300, {mysql_time()});")
+    sql.commit()
 
-# 生成基本 admin 用户
-uid = create_uid("admin", admin_passwd)
-cursor.execute(f"INSERT INTO user(UserID, Name, IsManager, Phone, Score, Reputation, CreateTime) "
-               f"VALUES ('{uid}', 'admin', 1, '{admin_phone}', 10, 300, {mysql_time()});")
-sql.commit()
 
 print("是否伪造数据?")
 if input("[Y/n]") != "Y":
@@ -112,19 +119,17 @@ def random_phone() -> str:
 
 
 def random_time() -> str:
-    r_time = time.time()
-    r_h = random.randint(0, 4 * 24)
-    r_time -= r_h * 60 * 60
-    return mysql_time(r_time)
+    r_time = int(time.time())
+    r_start = int(r_time - 35 * 24 * 60 * 60)
+    return mysql_time(random.randint(r_start, r_time))
 
 
 def random_time_double() -> tuple[str, str]:
-    r_time2 = r_time1 = time.time()
-    r_h1 = random.randint(0, 4 * 24)
-    r_h2 = random.randint(0, 4 * 24)
-    r_time1 -= min(r_h1, r_h2) * 60 * 60
-    r_time2 -= max(r_h1, r_h2) * 60 * 60
-    return mysql_time(r_time1), mysql_time(r_time2)
+    r_time = int(time.time())
+    r_start = int(r_time - 35 * 24 * 60 * 60)
+    r_h1 = random.randint(r_start, r_time)
+    r_h2 = random.randint(r_start, r_time)
+    return mysql_time(min(r_h1, r_h2)), mysql_time(max(r_h1, r_h2))
 
 
 def random_user(r_name, r_passwd, r_phone, r_time, is_manager: int, cur):
@@ -164,12 +169,15 @@ def random_garbage_u(r_time, r_time2, cur):
 
 print("步骤1, 注册管理账户[输入q结束]:")
 while True:
-    name = input("输入用户名:")
-    passwd = input("输入密码:")
-    phone = input("输入手机号码[输入x表示随机]:")
-    creat_time = input("是否随机时间[n=不随机 y=随机]:")
-    if name == 'q' or passwd == 'q' or phone == 'q' or creat_time == 'q':
+    if (name := input("输入用户名:")) == 'q':  # 这里使用了海象表达式, 把赋值运算变成一种表达式
+        break
+    if (passwd := input("输入密码:")) == 'q':
+        break
+    if (phone := input("输入手机号码[输入x表示随机]:")) == 'q':
         break
+    if (creat_time := input("是否随机时间[n=不随机 y=随机]:")) == 'q':
+        break
+
     if phone == 'x':
         phone = random_phone()
     if creat_time == 'n':
@@ -180,16 +188,20 @@ while True:
 
 print("步骤2, 注册普通账户[输入q结束]:")
 while True:
-    name = input("输入用户名:")
-    passwd = input("输入密码:")
-    phone = input("输入手机号码[输入x表示随机]:")
-    creat_time = input("是否随机时间[n=不随机 y=随机]:")
-    if name == 'q' or passwd == 'q' or phone == 'q' or creat_time == 'q':
+    if (name := input("输入用户名:")) == 'q':  # 这里使用了海象表达式, 把赋值运算变成一种表达式
         break
+    if (passwd := input("输入密码:")) == 'q':
+        break
+    if (phone := input("输入手机号码[输入x表示随机]:")) == 'q':
+        break
+    if (creat_time := input("是否随机时间[n=不随机 y=随机]:")) == 'q':
+        break
+
     if creat_time == 'n':
         c_time = mysql_time()
     else:
         c_time = random_time()
+
     if phone == 'x':
         phone = random_phone()
     random_user(name, passwd, phone, c_time, 0, cursor)