ソースを参照

feat: aliyun封装面向对象

SongZihuan 3 年 前
コミット
4b6b85e50a
7 ファイル変更96 行追加60 行削除
  1. 29 25
      conf/args.py
  2. 10 8
      conf/matplotlib_conf.py
  3. 29 12
      equipment/aliyun.py
  4. 1 1
      equipment/scan.py
  5. 10 1
      main.py
  6. 7 5
      sql/mysql_db.py
  7. 10 8
      tk_ui/station.py

+ 29 - 25
conf/args.py

@@ -27,33 +27,37 @@ p_args = parser.parse_args()
 
 if p_args.mysql_url is None or p_args.mysql_name is None or p_args.mysql_passwd is None:
     res = os.environ.get('HGSSystem_MySQL')
-    if res is None:
-        warnings.warn("未找到MySQL地址")
-        exit(1)
-    res = res.split(';')
-    if len(res) == 4:
-        p_args.mysql_url = [res[0]]
-        p_args.mysql_name = [res[1]]
-        p_args.mysql_passwd = [res[2]]
-        p_args.mysql_port = [res[3]]
-    elif len(res) == 3:
-        p_args.mysql_url = [res[0]]
-        p_args.mysql_name = [res[1]]
-        p_args.mysql_passwd = [res[2]]
-        p_args.mysql_port = [None]
+    if res is not None:
+        res = res.split(';')
+        if len(res) == 4:
+            p_args.mysql_url = [res[0]]
+            p_args.mysql_name = [res[1]]
+            p_args.mysql_passwd = [res[2]]
+            p_args.mysql_port = [res[3]]
+        elif len(res) == 3:
+            p_args.mysql_url = [res[0]]
+            p_args.mysql_name = [res[1]]
+            p_args.mysql_passwd = [res[2]]
+            p_args.mysql_port = [None]
+        else:
+            warnings.warn("MYSQL地址错误")
+            exit(1)
     else:
-        warnings.warn("MYSQL地址错误")
-        exit(1)
+        p_args.mysql_url = [None]
+        p_args.mysql_name = [None]
+        p_args.mysql_passwd = [None]
+        p_args.mysql_port = [None]
 
 if p_args.aliyun_key is None or p_args.aliyun_secret is None:
     res = os.environ.get('HGSSystem_Aliyun')
-    if res is None:
-        warnings.warn("未找到阿里云认证")
-        exit(1)
-    res = res.split(';')
-    if len(res) == 2:
-        p_args.aliyun_key = [res[0]]
-        p_args.aliyun_secret = [res[1]]
+    if res is not None:
+        res = res.split(';')
+        if len(res) == 2:
+            p_args.aliyun_key = [res[0]]
+            p_args.aliyun_secret = [res[1]]
+        else:
+            warnings.warn("阿里云认证错误")
+            exit(1)
     else:
-        warnings.warn("阿里云认证错误")
-        exit(1)
+        p_args.aliyun_key = [None]
+        p_args.aliyun_secret = [None]

+ 10 - 8
conf/matplotlib_conf.py

@@ -1,6 +1,3 @@
-import matplotlib.font_manager as fm
-
-
 class ConfigMatplotlibRelease:
     matplotlib_font = "SimHei"
     matplotlib_font_dict = dict(family=matplotlib_font)
@@ -8,8 +5,13 @@ class ConfigMatplotlibRelease:
 
 ConfigMatplotlib = ConfigMatplotlibRelease
 
-if "SimHei" not in [f.name for f in fm.fontManager.ttflist]:
-    print("请安装SimHei字体")
-    exit(1)
-fm.rcParams["font.sans-serif"] = [ConfigMatplotlib.matplotlib_font]  # 配置中文字体
-fm.rcParams["axes.unicode_minus"] = False  # 解决负号变豆腐块
+try:
+    import matplotlib.font_manager as fm
+except ImportError:
+    pass
+else:
+    if "SimHei" not in [f.name for f in fm.fontManager.ttflist]:
+        print("请安装SimHei字体")
+        exit(1)
+    fm.rcParams["font.sans-serif"] = [ConfigMatplotlib.matplotlib_font]  # 配置中文字体
+    fm.rcParams["axes.unicode_minus"] = False  # 解决负号变豆腐块

+ 29 - 12
equipment/aliyun.py

@@ -4,22 +4,39 @@
 import json
 
 from conf import Config
+from tool.type_ import *
 from viapi.fileutils import FileUtils
 from aliyunsdkcore.client import AcsClient
 from aliyunsdkimagerecog.request.v20190930 import ClassifyingRubbishRequest
-from aliyunsdkcore.acs_exception.exceptions import ClientException, ServerException
+from aliyunsdkcore.acs_exception.exceptions import ClientException, ServerException  # 外部需要使用
 
 
-def oss_file(file, suffix, is_local: bool = True):
-    """调用临时对象存储"""
-    file_utils = FileUtils(Config.aliyun_key, Config.aliyun_secret)
-    return file_utils.get_oss_url(file, suffix, is_local)
+class AliyunClientException(ClientException):
+    ...
 
 
-def garbage_search(img_url: str) -> dict:
-    """搜索图片是否为垃圾"""
-    client = AcsClient(Config.aliyun_key, Config.aliyun_secret, Config.aliyun_region_id)
-    response = ClassifyingRubbishRequest.ClassifyingRubbishRequest()
-    response.set_ImageURL(img_url)
-    res: bytes = client.do_action_with_exception(response)
-    return json.loads(res.decode('utf-8'))
+class AliyunServerException(ServerException):
+    ...
+
+
+class Aliyun:
+    def __init__(self,
+                 key: Optional[str] = Config.aliyun_key,
+                 secret: Optional[str] = Config.aliyun_secret,
+                 region_id: Optional[str] = Config.aliyun_region_id):
+        self._key = key
+        self._secret = secret
+        self._region_id = region_id
+
+    def oss_file(self, file, suffix, is_local: bool = True):
+        """调用临时对象存储"""
+        file_utils = FileUtils(self._key, self._secret)
+        return file_utils.get_oss_url(file, suffix, is_local)
+
+    def garbage_search(self, img_url: str) -> dict:
+        """搜索图片是否为垃圾"""
+        client = AcsClient(self._key, self._secret, self._region_id)
+        response = ClassifyingRubbishRequest.ClassifyingRubbishRequest()
+        response.set_ImageURL(img_url)
+        res: bytes = client.do_action_with_exception(response)
+        return json.loads(res.decode('utf-8'))

+ 1 - 1
equipment/scan.py

@@ -1,6 +1,6 @@
 import time
 import threading
-import cv2 as cv2
+import cv2.cv2 as cv2
 
 from conf import Config
 import qrcode

+ 10 - 1
main.py

@@ -31,9 +31,17 @@ def main():
 
     from sql.db import DB
 
+    if Config.mysql_url is None or Config.mysql_name is None:
+        print("MySQL 错误")
+        exit(1)
     mysql = DB()
 
     if program_name == "garbage":
+        from equipment.aliyun import Aliyun
+        if Config.aliyun_key is None or Config.aliyun_secret is None:
+            print("Aliyun key 错误")
+            exit(1)
+
         try:
             from equipment.scan import HGSCapture, HGSQRCoder
             import tk_ui.station as garbage_station
@@ -41,9 +49,10 @@ def main():
             can_not_load("垃圾站系统")
             sys.exit(1)
 
+        aliyun = Aliyun()
         cap = HGSCapture()
         qr = HGSQRCoder(cap)
-        station = garbage_station.GarbageStation(mysql, cap, qr)
+        station = garbage_station.GarbageStation(mysql, cap, qr, aliyun)
         station.mainloop()
     elif program_name == "ranking":
         try:

+ 7 - 5
sql/mysql_db.py

@@ -3,16 +3,18 @@ import threading
 import traceback
 
 from conf import Config
-from .base_db import HGSDatabase, DBCloseException
+from .base_db import HGSDatabase, DBException, DBCloseException
 from tool.type_ import *
 
 
 class MysqlDB(HGSDatabase):
     def __init__(self,
-                 host: str = Config.mysql_url,
-                 name: str = Config.mysql_name,
-                 passwd: str = Config.mysql_passwd,
-                 port: str = Config.mysql_port):
+                 host: Optional[str] = Config.mysql_url,
+                 name: Optional[str] = Config.mysql_name,
+                 passwd: Optional[str] = Config.mysql_passwd,
+                 port: Optional[str] = Config.mysql_port):
+        if host is None or name is None:
+            raise DBException
         super(MysqlDB, self).__init__(host, name, passwd, port)
         try:
             self._db = pymysql.connect(user=self._name,

+ 10 - 8
tk_ui/station.py

@@ -1,7 +1,7 @@
 import os.path
 import time
 import tempfile
-import cv2
+import cv2.cv2 as cv2
 import random
 import traceback
 import abc
@@ -23,7 +23,7 @@ from sql.user import update_user, find_user_by_id
 from sql.garbage import update_garbage
 
 from equipment.scan import HGSCapture, HGSQRCoder
-from equipment.aliyun import oss_file, garbage_search, ClientException, ServerException
+from equipment.aliyun import Aliyun, AliyunClientException, AliyunServerException
 
 from .event import TkEventMain
 
@@ -50,10 +50,12 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
                  db: DB,
                  cap: HGSCapture,
                  qr: HGSQRCoder,
+                 aliyun: Aliyun,
                  loc: location_t = Config.base_location):
         self._db: DB = db
         self._cap: HGSCapture = cap
         self._qr: HGSQRCoder = qr
+        self._aliyun: Aliyun = aliyun
         self._loc: location_t = loc
 
         self._user: Optional[User] = None  # 操作者
@@ -365,12 +367,11 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
         event = tk_event.RankingEvent(self)
         self.push_event(event)
 
-    @staticmethod
-    def search_core(temp_dir: tempfile.TemporaryDirectory, file_path: str) -> Optional[Dict]:
+    def search_core(self, temp_dir: tempfile.TemporaryDirectory, file_path: str) -> Optional[Dict]:
         try:
-            img_url = oss_file(file_path, "jpg", True)
-            res = garbage_search(img_url)
-        except (ClientException, ServerException):
+            img_url = self._aliyun.oss_file(file_path, "jpg", True)
+            res = self._aliyun.garbage_search(img_url)
+        except (AliyunClientException, AliyunServerException):
             return None
         else:
             return res
@@ -511,11 +512,12 @@ class GarbageStation(GarbageStationBase):
                  db: DB,
                  cap: HGSCapture,
                  qr: HGSQRCoder,
+                 aliyun: Aliyun,
                  loc: location_t = Config.base_location,
                  refresh_delay: int = Config.tk_refresh_delay):
         self.init_after_run_list: List[Tuple[int, Callable, Tuple]] = []
 
-        super(GarbageStation, self).__init__(db, cap, qr, loc)
+        super(GarbageStation, self).__init__(db, cap, qr, aliyun, loc)
         self.refresh_delay = refresh_delay
 
         self._window = tk.Tk()  # 系统窗口