Преглед изворни кода

feat: 新增阿里云垃圾识别

SongZihuan пре 3 година
родитељ
комит
19036bbf31
10 измењених фајлова са 198 додато и 18 уклоњено
  1. 5 1
      conf/__init__.py
  2. 10 0
      conf/aliyun.py
  3. 16 0
      conf/args.py
  4. 1 0
      conf/sys_default.py
  5. 25 0
      equipment/aliyun.py
  6. 1 1
      tk_ui/ranking.py
  7. 105 14
      tk_ui/station.py
  8. 30 1
      tk_ui/station_event.py
  9. 4 0
      tool/thread_.py
  10. 1 1
      tool/type_.py

+ 5 - 1
conf/__init__.py

@@ -8,10 +8,14 @@ from picture import head_pic, rank_bg_pic
 from args import p_args
 from .equipment import ConfigCapture
 from .sql import ConfigDatabase
+from .aliyun import ConfigAliyun
 from .sys_default import ConfigExport, ConfigSystem, ConfigSecret, ConfigTkinter, ConfUser
 
 
-class Config(ConfigTkinter, ConfigSecret, ConfigSystem, ConfUser, ConfigExport, ConfigDatabase, ConfigCapture):
+class Config(ConfigTkinter, ConfigSecret, ConfigSystem, ConfUser, ConfigExport,
+             ConfigAliyun,
+             ConfigDatabase,
+             ConfigCapture):
     run_type = p_args.run[0]
     program = p_args.program[0]
 

+ 10 - 0
conf/aliyun.py

@@ -0,0 +1,10 @@
+from . import args
+
+
+class ConfigAliyunRelease:
+    aliyun_key = args.p_args.aliyun_key[0]
+    aliyun_secret = args.p_args.aliyun_secret[0]
+    aliyun_region_id = "cn-shanghai"
+
+
+ConfigAliyun = ConfigAliyunRelease

+ 16 - 0
conf/args.py

@@ -8,6 +8,9 @@ parser.add_argument("--mysql_name", nargs=1, required=False, type=str, default=N
 parser.add_argument("--mysql_passwd", nargs=1, required=False, type=str, default=None, help="MySQL-密码")
 parser.add_argument("--mysql_port", nargs=1, required=False, type=str, default=[None], help="MySQL-端口")
 
+parser.add_argument("--aliyun_key", nargs=1, required=False, type=str, default=None, help="阿里云认证-KET")
+parser.add_argument("--aliyun_secret", nargs=1, required=False, type=str, default=None, help="阿里云认证-SECRET")
+
 parser.add_argument("--program", nargs=1, required=True, type=str, choices=["setup",
                                                                             "garbage",
                                                                             "ranking",
@@ -41,3 +44,16 @@ if p_args.mysql_url is None or p_args.mysql_name is None or p_args.mysql_passwd
     else:
         warnings.warn("MYSQL地址错误")
         exit(1)
+
+if p_args.aliyun_key is None or p_args.aliyun_secret is None:
+    res = os.environ.get('HGSSystem_Aliyun')
+    if res is None:
+        warnings.warn("未找到阿里云认证")
+        exit(1)
+    res = res.split(';')
+    if len(res) == 2:
+        p_args.aliyun_key = [res[0]]
+        p_args.aliyun_secret = [res[1]]
+    else:
+        warnings.warn("阿里云认证错误")
+        exit(1)

+ 1 - 0
conf/sys_default.py

@@ -12,6 +12,7 @@ class ConfUserRelease:
 
 class ConfigSystemRelease:
     base_location = "Guangdong-KZ"
+    search_reset_time = 10  # 搜索间隔的时间
     about_info = f'''
 HGSSystem is Garbage Sorting System
 

+ 25 - 0
equipment/aliyun.py

@@ -0,0 +1,25 @@
+"""阿里云 SDK调用 封装"""
+"""依赖模块: oss2, aliyun-python-sdk-viapiutils, viapi-utils, aliyun-python-sdk-imagerecog"""
+
+import json
+
+from conf import Config
+from viapi.fileutils import FileUtils
+from aliyunsdkcore.client import AcsClient
+from aliyunsdkimagerecog.request.v20190930 import ClassifyingRubbishRequest
+from aliyunsdkcore.acs_exception.exceptions import ClientException, ServerException
+
+
+def oss_file(file, suffix, is_local: bool = True):
+    """调用临时对象存储"""
+    file_utils = FileUtils(Config.aliyun_key, Config.aliyun_secret)
+    return file_utils.get_oss_url(file, suffix, is_local)
+
+
+def garbage_search(img_url: str) -> dict:
+    """搜索图片是否为垃圾"""
+    client = AcsClient(Config.aliyun_key, Config.aliyun_secret, Config.aliyun_region_id)
+    response = ClassifyingRubbishRequest.ClassifyingRubbishRequest()
+    response.set_ImageURL(img_url)
+    res: bytes = client.do_action_with_exception(response)
+    return json.loads(res.decode('utf-8'))

+ 1 - 1
tk_ui/ranking.py

@@ -270,7 +270,7 @@ class RankingStation(RankingStationBase):
         self.window.bind("<F11>", lambda _: self.__switch_full_screen())
 
     def __conf_windows_bg(self):
-        img = Image.open(Config.picture_d['rank_bg']).resize((self.width, self.height))
+        img = Image.open(Config.picture_d['rank_bg']).resize((self.width, self.height), Image.ANTIALIAS)
         self.bg_img = ImageTk.PhotoImage(img)
         self.bg_lb['image'] = self.bg_img
         self.bg_lb.place(relx=0, rely=0, relwidth=1, relheight=1)

+ 105 - 14
tk_ui/station.py

@@ -1,4 +1,6 @@
+import os.path
 import time
+import tempfile
 import cv2
 import random
 import traceback
@@ -11,6 +13,7 @@ from PIL import Image, ImageTk
 from conf import Config
 from tool.type_ import *
 from tool.tk import set_tk_disable_from_list, make_font
+from tool.thread_ import getThreadIdent
 
 from core.user import User
 from core.garbage import GarbageBag, GarbageType
@@ -20,6 +23,7 @@ from sql.user import update_user, find_user_by_id
 from sql.garbage import update_garbage
 
 from equipment.scan import HGSCapture, HGSQRCoder
+from equipment.aliyun import oss_file, garbage_search, ClientException, ServerException
 
 from .event import TkEventMain
 
@@ -62,6 +66,8 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
 
         self.rank = None
         self.rank_index = 0
+
+        self.search_time = 0  # 上次执行搜索任务的时间
         super(GarbageStationBase, self).__init__()
 
     def get_db(self):
@@ -309,13 +315,6 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
 [期待再次与你相遇]
                 '''.strip(), show_time=15.0)
 
-    def show_search_info(self):
-        self.update_user_time()
-        self.show_msg("搜索", f'''
-搜索功能将根据摄像头获取物品信息, 反馈该物品垃圾类型。
-该功能尚未开放, 敬请期待
-                '''.strip(), show_time=5.0)
-
     def thread_show_rank(self, rank_list):
         self.rank = [[]]
         for i, r in enumerate(rank_list):
@@ -366,12 +365,81 @@ class GarbageStationBase(TkEventMain, metaclass=abc.ABCMeta):
         event = tk_event.RankingEvent(self)
         self.push_event(event)
 
+    @staticmethod
+    def search_core(temp_dir: tempfile.TemporaryDirectory, file_path: str) -> Optional[Dict]:
+        try:
+            img_url = oss_file(file_path, "jpg", True)
+            res = garbage_search(img_url)
+        except (ClientException, ServerException):
+            return None
+        else:
+            return res
+        finally:
+            temp_dir.cleanup()
+
+    def get_search_result(self, res: dict) -> bool:
+        self.search_time = time.time()
+        data: Optional[dict] = res.get("Data")
+        if data is None:
+            self.show_warning("搜索垃圾", "搜索垃圾时发生错误")
+            return False
+
+        sensitive = data.get("Sensitive")
+        if sensitive is None:
+            self.show_warning("搜索垃圾", "搜索垃圾时发生错误")
+            return False
+        elif sensitive:
+            self.show_warning("搜索垃圾", "图片不够清晰")  # 图片违规
+            return False
+        elements: List[Dict] = data.get("Elements")
+        assert elements is not None
+
+        res_str = f"搜索结果为 [共{len(elements)}项]:\n\n"
+        for i, element in enumerate(elements):
+            name_ = element.get("Rubbish")
+            name_score = element.get("RubbishScore")
+            if len(name_) == 0:
+                name = "未知物品"
+            else:
+                name = f"{name_} [可信度: {name_score * 100}%]"
+            if name_score < 0.001:
+                name_score = 0.01
+            category_ = element.get("Category")
+            category_score = element.get("CategoryScore")
+            if len(category_) == 0:
+                category = "未知垃圾类型"
+            else:
+                category_ = {"可回收垃圾": "可回收垃圾",
+                             "干垃圾": "其他垃圾",
+                             "湿垃圾": "厨余垃圾",
+                             "有害垃圾": "有害垃圾"}.get(category_, "未知垃圾类型")
+                category = f"垃圾类型为{category_} [可信度: {(category_score / name_score) * 100}%]"
+            res_str += f"  NO.{i + 1} {name}\n  {category}\n"
+        self.show_msg("搜索垃圾", res_str, show_time=30, big=True)
+
+    def search_pic(self, img: Image.Image) -> bool:
+        sep = time.time() - self.search_time
+        if sep < 3:
+            self.show_warning("搜索垃圾", f"搜索太频繁了\n请稍后再尝试")
+            return False
+        elif sep <= Config.search_reset_time:
+            self.show_warning("搜索垃圾", f"搜索太频繁了\n请{sep}s后再尝试")
+            return False
+
+        temp_dir = tempfile.TemporaryDirectory()
+        tid = getThreadIdent()
+        file_path = os.path.join(temp_dir.name, f"search-{tid}.jpg")
+        img.save(file_path, 'JPEG', quality=100)
+        event = tk_event.SearchGarbageEvent(self).start(temp_dir, file_path)
+        self.push_event(event)
+        return True
+
     @abc.abstractmethod
     def show_msg(self, title, info, msg_type='info', big: bool = True, show_time: float = 10.0):
         ...
 
     @abc.abstractmethod
-    def show_warning(self, title, info, show_time: float = 15.0):
+    def show_warning(self, title, info, show_time: float = 5.0):
         ...
 
     @abc.abstractmethod
@@ -458,6 +526,7 @@ class GarbageStation(GarbageStationBase):
         self.__conf_windows()
 
         self._cap_img = None  # 存储 PIL.image 的变量 防止gc释放
+        self._cap_img_tk = None  # 存储 tkinter的image 的变量 防止gc释放
         self._user_im = None
 
         self._msg_time: Optional[float] = None  # msg 显示时间累计
@@ -517,6 +586,8 @@ class GarbageStation(GarbageStationBase):
 
         # 摄像头显示
         self._cap_label = tk.Label(self._window)
+        self._cap_width = 0
+        self._cap_height = 0
 
         # 用户操纵按钮
         self._user_btn_frame = tk.Frame(self._window)
@@ -860,13 +931,15 @@ class GarbageStation(GarbageStationBase):
             height += height_label + height_sep
 
         self._user_btn[0]['state'] = 'disable'  # 第一个按键默认为disable且点击无效果
-        self._user_btn[1]['command'] = lambda: self.get_show_rank()
-        self._user_btn[2]['command'] = lambda: self.show_search_info()
+        self._user_btn[1]['command'] = self.get_show_rank
+        self._user_btn[2]['command'] = self.search_pic
 
     def __conf_cap_label(self):
         self._cap_label['bg'] = "#000000"
         self._cap_label['bd'] = 5
         self._cap_label['relief'] = "ridge"
+        self._cap_width = int(self._win_width * 0.2)
+        self._cap_height = int(self._win_height * 0.32)
         self._cap_label.place(relx=0.22, rely=0.66, relwidth=0.2, relheight=0.32)
 
     def __conf_msg(self):
@@ -954,7 +1027,7 @@ class GarbageStation(GarbageStationBase):
 
         self.set_msg_time_now(show_time)
 
-    def show_warning(self, title, info, show_time: float = 15.0):
+    def show_warning(self, title, info, show_time: float = 5.0):
         self.show_msg(title, info, msg_type='警告', show_time=show_time)
 
     def __conf_rank(self):
@@ -1092,6 +1165,11 @@ class GarbageStation(GarbageStationBase):
         self._loading_pro.stop()
         self.set_reset_all_btn()
 
+    def search_pic(self, img: Image = None):
+        if img is None:
+            img = self._cap_img
+        super(GarbageStation, self).search_pic(img)
+
     def __show_check_frame(self):
         self._check_ctrl_frame.place(relx=0.45, rely=0.82, relwidth=0.53, relheight=0.16)
 
@@ -1177,8 +1255,22 @@ class GarbageStation(GarbageStationBase):
         # 需要存储一些数据 谨防被gc释放
         _cap_img_info = (Image.fromarray(cv2.cvtColor(self.get_cap_img(), cv2.COLOR_BGR2RGB)).
                          transpose(Image.FLIP_LEFT_RIGHT))
-        self._cap_img = ImageTk.PhotoImage(image=_cap_img_info)
-        self._cap_label['image'] = self._cap_img
+        self._cap_img = _cap_img_info
+
+        img_width, img_height = _cap_img_info.size
+        proportion = max(self._cap_width / img_width, self._cap_height / img_height)  # 缩放倍数, 取较大的那个
+        new_width = int(img_width * proportion)
+        new_height = int(img_height * proportion)
+        _cap_img_info = _cap_img_info.resize((new_width, new_height), Image.ANTIALIAS)
+
+        crop = (int(new_width / 2 - self._cap_width / 2),  # 左
+                int(new_height / 2 - self._cap_height / 2),  # 上
+                int(new_width / 2 + self._cap_width / 2),  # 右
+                int(new_height / 2 + self._cap_height / 2))  # 下
+        _cap_img_info = _cap_img_info.crop(crop)  # 裁剪图片
+
+        self._cap_img_tk = ImageTk.PhotoImage(image=_cap_img_info)
+        self._cap_label['image'] = self._cap_img_tk
 
     def update_msg(self):
         if self._msg_time is None:
@@ -1240,4 +1332,3 @@ class GarbageStation(GarbageStationBase):
 
     def exit_win(self):
         self._window.destroy()
-

+ 30 - 1
tk_ui/station_event.py

@@ -1,10 +1,12 @@
+import tempfile
+
 from equipment.scan import QRCode
 from equipment.scan_user import scan_user
 from equipment.scan_garbage import scan_garbage
 
 from tool.type_ import *
 
-from core.user import User, UserNotSupportError
+from core.user import User
 from core.garbage import GarbageBag
 
 from sql.db import DB
@@ -192,3 +194,30 @@ class CheckGarbageEvent(StationEventBase):
             self.station.show_warning("垃圾检测", "数据库操作失败", show_time=3.0)
         else:
             self.station.show_msg("垃圾检测", "垃圾检测提结果交成功", show_time=3.0)
+
+
+class SearchGarbageEvent(StationEventBase):
+    """
+    任务: 搜索垃圾垃圾的结果
+    """
+
+    def func(self, temp_dir: tempfile.TemporaryDirectory, file_path: str):
+        return self.station.search_core(temp_dir, file_path)
+
+    def __init__(self, gb_station):
+        super().__init__(gb_station, "搜索垃圾")
+        self.thread = None
+
+    def start(self, temp_dir: tempfile.TemporaryDirectory, file_path: str):
+        self.thread = TkThreading(self.func, temp_dir, file_path)
+        return self
+
+    def is_end(self) -> bool:
+        return not self.thread.is_alive()
+
+    def done_after_event(self):
+        res = self.thread.wait_event()
+        if res is None:
+            self.station.show_warning("垃圾搜索", "垃圾搜索发生错误")
+        else:
+            self.station.get_search_result(res)

+ 4 - 0
tool/thread_.py

@@ -36,3 +36,7 @@ class Threading(threading.Thread):
         """
         self.join()
         return self.result
+
+
+def getThreadIdent():
+    return threading.currentThread().ident

+ 1 - 1
tool/type_.py

@@ -1,4 +1,4 @@
-from typing import Dict, List, Tuple, Union, Optional, Callable
+from typing import Dict, List, Tuple, Union, Optional, Callable, IO
 
 gid_t = str  # garbage bag id 类型
 uid_t = str  # user id 类型