__author__ Const import os import numpy as np import numexpr from grap

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
__author__ = 'Const'
import os
import numpy as np
import numexpr
from graphviz import Digraph
from pandas import read_table
from math import log
from collections import defaultdict
# Module-level graph: print_tree_recursive adds nodes/edges to it and main renders it.
dot = Digraph(comment='TREE')
def H(*values):
    """Return the base-2 entropy of the distribution given by ``values``.

    ``values`` are non-negative counts (or weights); they are normalised by
    their sum before the entropy is computed. Zero entries contribute nothing.

    Returns 0.0 when the counts sum to zero (including the no-argument case),
    instead of dividing by zero and relying on NaN comparisons.
    """
    counts = np.asarray(values) * 1.0
    total = counts.sum()
    if total == 0:
        # Nothing to normalise: an empty distribution carries no information.
        return 0.0
    probabilities = counts / total
    return sum(-p * log(p, 2) for p in probabilities if p > 0.0)
def calc_pn(p, n, P, N):
    """Return entropy figures for splitting (P, N) into (p, n) and the rest.

    The result is a 4-tuple:
      0. entropy of the whole set, H(P, N)
      1. entropy of the selected part, H(p, n)
      2. entropy of the remainder, H(P - p, N - n)
      3. entropy of the whole set minus the size-weighted entropies of the
         two parts
    """
    whole_entropy = H(P, N)
    part_entropy = H(p, n)
    rest_entropy = H(P - p, N - n)

    part_weight = (p + n) * 1.0 / (P + N)
    rest_weight = (P + N - p - n) * 1.0 / (P + N)
    difference = whole_entropy - part_weight * part_entropy - rest_weight * rest_entropy

    return (whole_entropy, part_entropy, rest_entropy, difference)
def read_csv():
    """Load ``./train2.csv`` into a DataFrame.

    The first row is used as the header and the first column as the index.
    The file is opened in a ``with`` block so the handle is closed once the
    table has been read.
    """
    # NOTE(review): r"\s*" can match the empty string; Python 3.7+ re.split
    # splits on empty matches, so r"\s+" is probably what is intended here.
    # Left unchanged -- confirm against the data file before switching.
    with open('./train2.csv', 'r') as source:
        return read_table(source, header=0, sep=r"\s*", index_col=0)
class Tree:
    """A node of the decision tree.

    Each node carries a ``condition`` string (initially "No"), a ``data``
    dict, and a ``childs`` dict of child nodes.
    """

    def __init__(self):
        self.childs = dict()
        self.data = dict()
        self.condition = "No"

    def __str__(self):
        label = self.condition
        return str(label)
def build_tree(df):
    """Create a fresh root node and grow the tree for ``df`` beneath it.

    Returns whatever ``build_tree_recursive`` returns for the new root.
    """
    return build_tree_recursive(df, Tree())
def build_tree_recursive(df, node):
    """Grow the subtree for ``df`` under ``node`` and return ``node``.

    ``df`` must have a ``'Class'`` column whose 'Yes'/'No' values are counted
    as positives/negatives. Every column except the last is treated as a
    candidate feature.

    For each candidate column the score is the largest ``calc_pn(...)[3]``
    over its distinct values; the column with the highest score is used for
    branching. One child is created per value of that column. A child becomes
    a leaf (``condition`` set to 'Yes' or 'No') when its rows are all of one
    class; otherwise the chosen column is dropped and the child is grown
    recursively.

    On return ``node.data`` holds 'Name' (the chosen column), 'Pos' and 'Neg'.
    """
    class_counts = dict(df.groupby(['Class'])['Class'].count())
    P = class_counts.get('Yes', 0)
    N = class_counts.get('No', 0)

    gain = {}
    max_column = df.columns.values[0]
    for column in df.columns.values[:-1]:
        features = defaultdict(int)
        features.update(dict(df.groupby([column, 'Class'])['Class'].count()))
        # Materialise the distinct values first: looking up a missing
        # (value, class) pair below inserts it into the defaultdict.
        distinct_values = {key[0] for key in features}
        gain[column] = max(
            calc_pn(features[(x, 'Yes')], features[(x, 'No')], P, N)[3]
            for x in distinct_values
        )
        if gain[column] > gain[max_column]:
            max_column = column

    # With a single column left there is nothing to branch on. The column
    # count does not depend on the value, so check it once up front.
    if len(df.columns.values) > 1:
        # Iterating the groupby yields each value together with its rows,
        # which avoids building a query string from the value (that breaks on
        # non-string values, quotes, or column names that are not identifiers).
        for x, subset in df.groupby(max_column):
            child = Tree()
            node.childs[x] = child
            value = dict(subset.groupby(['Class'])['Class'].count())
            p = value.get('Yes', 0)
            n = value.get('No', 0)
            # NOTE(review): besides pure subsets (p == 0 or n == 0), the child
            # is also made a leaf when the entropy of the *remaining* rows
            # (calc_pn index 2) is exactly 1.0. Kept as-is -- confirm intent.
            if calc_pn(p, n, P, N)[2] == 1.0 or p == 0 or n == 0:
                child.condition = 'Yes' if p > n else 'No'
                child.data['Name'] = x
                child.data['Pos'] = p
                child.data['Neg'] = n
                continue
            # drop() returns a new frame, so the caller's data is not mutated.
            build_tree_recursive(subset.drop(max_column, axis=1), child)

    node.data['Name'] = max_column
    node.data['Pos'] = P
    node.data['Neg'] = N
    return node
def print_tree(tree):
    """Print ``tree`` starting from its root by delegating to print_tree_recursive."""
    print_tree_recursive(node=tree)
def print_tree_recursive(node, node_id='n'):
    """Print ``node`` and its subtree and add them to the module-level ``dot`` graph.

    ``node_id`` is the graph identifier used for ``node``; children get ids
    derived from it (``<node_id>_<index>``), so every tree node maps to its
    own graph node. The text "Name, Pos:Neg" is used only as the display
    label. Using that text as the node name (as before) had two problems:
    ``Digraph.edge`` splits names on ':' into node and port, so edges did not
    point at the declared nodes, and tree nodes with equal labels were merged.
    """
    name = str(node.data['Name'])
    positives = str(node.data['Pos'])
    negatives = str(node.data['Neg'])

    if node.data['Name']:
        print("current " + name + " p: " + positives + " n: " + negatives)
    print("childs :")
    dot.node(node_id, name + ", " + positives + ":" + negatives)

    for x in node.childs:
        print("child " + str(x))

    for index, x in enumerate(node.childs):
        child_id = "%s_%d" % (node_id, index)
        dot.edge(node_id, child_id, label=str(x))
        # The recursive call declares the child's own graph node and label.
        print_tree_recursive(node.childs[x], child_id)
def main():
    """Build the tree from the training file, print it, and render the graph.

    The DOT source is printed and the graph is rendered to ``tree.gv`` and
    opened in a viewer.
    """
    tree = build_tree(read_csv())
    print_tree(tree)
    # Call form of print: matches the rest of the file and, with a single
    # argument, behaves the same on Python 2 and Python 3.
    print(dot.source)
    dot.render('tree.gv', view=True)


if __name__ == "__main__":
    main()