浏览代码

绘图方法:预测型热力图

Huan 5 年之前
父节点
当前提交
06b282cf40
共有 1 个文件被更改,包括 47 次插入0 次删除
  1. 47 0
      Learn_Numpy.py

+ 47 - 0
Learn_Numpy.py

@@ -879,6 +879,53 @@ class UnsupervisedModel(prep_Base):
         self.Model.fit(x_data)
         return 'None', 'None'
 
+class Predictive_HeatMap(prep_Base):#绘制预测型热力图
+    def __init__(self, args_use, Learner, *args, **kwargs):  # model表示当前选用的模型类型,Alpha针对正则化的参数
+        super(Predictive_HeatMap, self).__init__(*args, **kwargs)
+
+        self.Model = Learner.Model
+        self.Select_Model = None
+        self.have_Fit = Learner.have_Fit
+        self.Model_Name = 'Select_Model'
+        self.x_trainData = self.x_trainData
+        self.y_trainData = self.y_trainData
+
+    def Des(self,Dic,*args,**kwargs):
+        tab = Tab()
+        y = self.y_trainData
+        x_data = self.x_trainData
+        try:#如果没有class
+            class_ = self.Model.classes_.tolist()
+            class_heard = [f'类别[{i}]' for i in range(len(class_))]
+
+            #获取数据
+            get,x_means,x_range,Type = Training_visualization(x_data,class_,y)
+
+            get = Decision_boundary(x_range,x_means,self.Model.Predict,class_,Type)
+            for i in range(len(get)):
+                tab.add(get[i], f'{i}预测热力图')
+
+            heard = class_heard + [f'普适预测第{i}特征' for i in range(len(x_means))]
+            data = class_ + [f'{i}' for i in x_means]
+            c = Table().add(headers=heard, rows=[data])
+            tab.add(c, '数据表')
+        except:
+            get, x_means, x_range,Type = regress_visualization(x_data, y)
+
+            get = Prediction_boundary(x_range, x_means, self.Model.Predict, Type)
+            for i in range(len(get)):
+                tab.add(get[i], f'{i}预测热力图')
+
+            heard = [f'普适预测第{i}特征' for i in range(len(x_means))]
+            data = [f'{i}' for i in x_means]
+            c = Table().add(headers=heard, rows=[data])
+            tab.add(c, '数据表')
+
+        save = Dic + r'/render.HTML'
+        tab.render(save)  # 生成HTML
+        return save,
+
+
 class Near_feature_scatter_class_More(Unsupervised):
     def __init__(self, args_use, model, *args, **kwargs):
         super(Near_feature_scatter_class_More, self).__init__(*args, **kwargs)