# Decision Tree

 ``` 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122``` ```__author__ = 'Const' import os import numpy as np import numexpr from graphviz import Digraph from pandas import read_table from math import log from collections import defaultdict dot = Digraph(comment='TREE') def H(*values): values = np.asarray(values) * 1.0 values = values / sum(values) return sum((-value * log(value, 2) for value in values if value > 0.0)) def calc_pn(p, n, P, N): return ( H(P, N), H(p, n), H(P - p, N - n), H(P, N) - (p + n) * 1.0 / (P + N) * H(p, n) - (P + N - p - n) * 1.0 / (P + N) * H(P - p, N - n), ) # return H(P, N) - (p + n) * 1.0 / (P + N) * H(p, n) - (P + N - p - n) * 1.0 / (P + N) * H(P - p, N - n) def read_csv(): return read_table(open('./train.csv', 'r'), header=0, sep=r",\s*", index_col=False) class Tree: def __init__(self): self.condition = "<=50K" self.data = {} self.childs = {} def __str__(self): return str(self.condition) def build_tree(df): tree = Tree() return build_tree_recursive(df, tree) def build_tree_recursive(df, node): val = dict(df.groupby(['Class'])['Class'].count()) P = val['>50K'] if '>50K' in val else 0 N = val['<=50K'] if '<=50K' in val else 0 gain = {} availible_features = list(df.columns.values[:-1]); max_column = df.columns.values[0] for column in df.columns.values[:-1]: features = defaultdict(lambda: 0) features.update(dict(df.groupby([column, 'Class'])['Class'].count())) val = set([x[0] for x in features.keys()]) sub_gain = [] for x in val: sub_gain.append(calc_pn(features[(x, '>50K')], features[(x, '<=50K')], P, N)[3]) gain[column] = max(sub_gain) max_column = column if gain[column] > gain[max_column] else max_column for x in dict(df.groupby(max_column)[max_column].count()).keys(): if len(df.query(max_column + " == '" + str(x) + "'").columns.values) == 1 or x == "?": continue node.childs[x] = Tree() query_to_go = df.query(max_column + " == '" + str(x) + "'") value = dict(query_to_go.groupby(['Class'])['Class'].count()) p = value['>50K'] if '>50K' in value else 0 n = value['<=50K'] if '<=50K' in value else 0 #node.childs[x].data['Name'] = max_column + " == '" + x + "'" if calc_pn(p, n, P, N)[2] == 1.0 or p == 0 or n == 0: node.childs[x].condition = '>50K' if p > n else '<=50K' node.childs[x].data['Name'] = x node.childs[x].data['Pos'] = p node.childs[x].data['Neg'] = n continue #print(max_column + " == '" + x + "'" + " " + str(calc_pn(p, n, P, N))) #print(query_to_go) del query_to_go[max_column] #print type(df.query(max_column + " == '" + x + "'")) build_tree_recursive(query_to_go, node.childs[x]) node.data['Name'] = max_column node.data['Pos'] = P node.data['Neg'] = N return node def print_tree(tree): print_tree_recursive(tree) def print_tree_recursive(node): if node.data['Name']: print("current " + str(node.data['Name']) + " p: " + str(node.data['Pos']) + " n: " + str(node.data['Neg'])) print("childs :") dot.node(str(node.data['Name']) + ", " + str(node.data['Pos']) + ":" + str(node.data['Neg']), str(node.data['Name']) + ", " + str(node.data['Pos']) + ":" + str(node.data['Neg'])) for x in node.childs: print("child " + str(x)) for x in node.childs: dot.node(str(node.childs[x].data['Name']) + ", " + str(node.childs[x].data['Pos']) + ":" + str(node.childs[x].data['Neg']), str(node.childs[x].data['Name']) + ", " + str(node.childs[x].data['Pos']) + ":" + str(node.childs[x].data['Neg'])) dot.edge(str(node.data['Name']) + ", " + str(node.data['Pos']) + ":" + str(node.data['Neg']), str(node.childs[x].data['Name']) + ", " + str(node.childs[x].data['Pos']) + ":" + str(node.childs[x].data['Neg']), label=str(x)) print_tree_recursive(node.childs[x]) return def main(): tree = build_tree(read_csv()) print_tree(tree) print dot.source dot.render('tree.gv', view=True) if __name__ == "__main__": main() ```