1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
   import os
   from collections import Counter

   import numpy as np
   from graphviz import Digraph

   # Put the Graphviz binaries on PATH so the graphviz package can find `dot`.
   # NOTE(review): hard-coded, machine-specific install location.
   os.environ["PATH"] += os.pathsep + 'D:/Graphviz2.39/bin'


   def loadData(filename):
       """Read a whitespace-separated numeric data file.

       Every non-blank line becomes one sample: a list of floats whose last
       element is used as the class label by the rest of this module.

       Args:
           filename: path of the text file to read.

       Returns:
           list[list[float]]: one row per non-blank line, in file order.

       Raises:
           ValueError: if a field cannot be parsed as a float.
       """
       res = []
       with open(filename) as f:
           for line in f:
               fields = line.split()
               if not fields:
                   # Skip blank lines (e.g. a trailing empty line); an empty
                   # row would break every later `sample[-1]` lookup.
                   continue
               res.append([float(x) for x in fields])
       return res


   def getEnt(data):
       """Return the Shannon entropy, in bits, of the labels in *data*.

       The label of a sample is its last element. Empty *data* gives 0.
       """
       num = len(data)
       labelCount = Counter(sample[-1] for sample in data)
       Ent = 0
       for count in labelCount.values():
           p = count / num
           Ent -= p * np.log2(p)
       return Ent


   def splitDate(data, feature, point):
       """Split *data* on column *feature* at threshold *point*.

       Returns:
           (left, right): samples whose value is ``<= point`` and samples
           whose value is ``> point``. Sending ties to the left matches
           ``predict`` and guarantees no sample is dropped.
       """
       data1 = [x for x in data if x[feature] <= point]
       data2 = [x for x in data if x[feature] > point]
       return data1, data2


   def chooseBF(data):
       """Find the binary split of *data* with the highest information gain.

       Each feature column is tried with every candidate threshold, namely
       the midpoint between consecutive distinct sorted values of that column.

       Args:
           data: non-empty list of samples; the last element of each sample
               is its class label.

       Returns:
           tuple: ``(feature, point, dataLeft, dataRight)`` for the best split
           (the first one found wins ties).

       Raises:
           ValueError: if every feature column is constant, so no split
               exists.
       """
       num = len(data)
       Ent = getEnt(data)
       numFeature = len(data[0]) - 1
       maxGain = float('-inf')
       best = None
       for i in range(numFeature):
           values = sorted({sample[i] for sample in data})
           for lo, hi in zip(values, values[1:]):
               point = (lo + hi) / 2
               data1, data2 = splitDate(data, i, point)
               Gain = (Ent
                       - (len(data1) / num) * getEnt(data1)
                       - (len(data2) / num) * getEnt(data2))
               if Gain > maxGain:
                   maxGain = Gain
                   best = (i, point, data1, data2)
       if best is None:
           raise ValueError("chooseBF: no feature has two distinct values")
       return best
 
  def creatTree(data):
      """Recursively build a binary decision tree from *data*.

      Args:
          data: non-empty list of samples; each sample is a list whose last
              element is the class label and whose other elements are
              numeric features.

      Returns:
          dict: either a leaf ``{'label': y}`` or an internal node
          ``{'value': threshold, 'feature': column, 'leftChild': subtree,
          'rightChild': subtree}`` built from the split chosen by
          ``chooseBF``.
      """
      node = {}
      label = [sample[-1] for sample in data]
      if len(set(label)) == 1:
          # Pure node: every sample shares one class.
          node['label'] = label[0]
          return node
      numFeature = len(data[0]) - 1
      # When every feature column is constant, no threshold can separate the
      # samples (and chooseBF would have nothing to choose), so fall back to a
      # majority-vote leaf. Checking the real column count rather than a
      # fixed number of features keeps this correct for datasets of any width.
      if all(len({sample[i] for sample in data}) == 1
             for i in range(numFeature)):
          node['label'] = Counter(label).most_common(1)[0][0]
          return node
      featureIndex, point, dataLeft, dataRight = chooseBF(data)
      node['value'] = point
      node['feature'] = featureIndex
      node['leftChild'] = creatTree(dataLeft)
      node['rightChild'] = creatTree(dataRight)
      return node
  def predict(sample, node):
      """Classify *sample* by walking the decision tree rooted at *node*.

      At each internal node the walk moves to ``'rightChild'`` when
      ``sample[node['feature']]`` is strictly greater than ``node['value']``
      and to ``'leftChild'`` otherwise; the ``'label'`` of the leaf that is
      reached is returned.
      """
      current = node
      while 'feature' in current:
          go_right = sample[current['feature']] > current['value']
          current = current['rightChild'] if go_right else current['leftChild']
      return current['label']
 
  def test(data, tree):
      """Return the tree's prediction for every sample in *data*.

      Args:
          data: sequence of samples; only the feature columns are read by
              ``predict``, so a trailing true label may be present.
          tree: decision-tree node dict as produced by ``creatTree``.

      Returns:
          list: ``predict(sample, tree)`` for each sample, in input order.
      """
      return [predict(sample, tree) for sample in data]
 
  def plot_model(tree, name):
      """Build a Graphviz diagram of a decision tree.

      Args:
          tree: node dict as produced by ``creatTree``; it may be a single
              leaf when the training data is pure.
          name: output filename handed to ``Digraph``.

      Returns:
          Digraph: the populated graph. Rendering/viewing is left to the
          caller (e.g. ``g.view()``).
      """
      global root
      # _sub_plot allocates node IDs by incrementing the module-level `root`
      # counter. Reset it so every call numbers the tree root "0" and
      # repeated calls produce identical IDs.
      root = "0"
      g = Digraph("G", filename=name, format='png', strict=False)
      if 'feature' in tree:
          g.node("0", str(tree['feature']))
      else:
          # Leaf-only tree: there is no feature to show.
          g.node("0", str(tree['label']))
      _sub_plot(g, tree, "0")
      return g
 
  # Module-level node-ID counter used by _sub_plot when adding Graphviz
  # nodes; it is kept as a decimal string and advanced via str(int(root) + 1).
  root = "0"
 
  def _sub_plot(g, tree, inc):
      """Recursively add the children of *tree* to the graph *g*.

      Args:
          g: graph object offering ``node(id, label)`` and
              ``edge(tail, head, label)`` (a graphviz ``Digraph`` here).
          tree: node dict; its 'leftChild', 'rightChild' and 'label' keys are
              handled in the dict's own key order, other keys are ignored.
          inc: ID of the graph node that *tree* is drawn as.

      Node IDs come from the module-level string counter ``root``, which is
      incremented for each child. A child is drawn with its feature index
      (or its bare label), linked to *inc* by an edge labelled '<value'
      (left) or '>value' (right), and then recursed into. On a 'label' key,
      the node whose ID is currently in ``root`` is redrawn as
      'label:<label>'.
      """
      global root
      for key in tree.keys():
          if key in ('leftChild', 'rightChild'):
              child = tree[key]
              root = str(int(root) + 1)
              caption = child['feature'] if 'feature' in child else child['label']
              g.node(root, str(caption))
              op = '<' if key == 'leftChild' else '>'
              g.edge(inc, root, op + str(tree['value']))
              _sub_plot(g, child, root)
          elif key == 'label':
              g.node(root, 'label:' + str(tree['label']))
 
 
  if __name__ == '__main__':
      # Train on traindata.txt, predict testdata.txt, draw the tree and
      # print the test-set accuracy (labels are the last column).
      trainData = loadData('traindata.txt')
      testData = loadData('testdata.txt')
      tree = creatTree(trainData)
      ans = test(testData, tree)
      trueLabel = [sample[-1] for sample in testData]
      rightCount = sum(1 for pred, truth in zip(ans, trueLabel)
                       if pred == truth)
      g = plot_model(tree, "决策树")
      g.view()
      if ans:
          print(rightCount / len(ans))
      else:
          # An empty test file would otherwise raise ZeroDivisionError.
          print('no test samples; accuracy is undefined')
   |