본문 바로가기

AI

1209 - 의사결정나무

728x90

딥러닝 기초 - DecisionTree & RandomForest


의사결정나무

  • 정보 이득 : '부모 노드'와 '자식 노드'의 불순도 합의 차이
    • 자식 노드의 불순도가 낮을수록 정보 이득이 커짐
  • 불순도 지표 : 지니 불순도, 엔트로피, 분류 오차
    • 불순도 조건을 바꾸는 것보다 가지치기 수준을 바꾸면서 튜닝하는게 성적 향상에 도움됨
- 설명이 중요할 때 아주 유용한 모델 
- 뿌리부터 정보 이득이 최대가 되는 특성으로 가지를 나눔
- 모든 이파리가 순수해질 때까지 반복
- 이 과정에서 과대적합될 가능성 高 → 따라서 최대 나무 가지 수를 제한 하는 등 가지치기 과정 필요

 

세가지 불순도 기준을 시각적으로 비교

import matplotlib.pyplot as plt
import numpy as np

def gini(p):
    return (p) * (1 - (p)) + (1 - p)*(1 - (1 - p))

def entropy(p):
    return - p * np.log2(p) - (1 - p)*np.log2((1 - p))

def error(p):
    return 1 - np.max([p, 1 - p])

x = np.arange(0.0, 1.0, 0.01)

ent = [entropy(p) if p != 0 else None for p in x]
sc_ent = [e * 0.5 if e else None for e in ent] # scaled_entropy = 1/2 ent
err = [error(i) for i in x]

fig = plt.figure()
ax = plt.subplot(111)

# impurity, label, line_style, color
for i, lab, ls, c, in zip([ent, sc_ent, gini(x), err], 
                          ['Entropy', 'Entropy (scaled)', 'Gini impurity', 'Misclassification error'],
                          ['-', '-', '--', '-.'], 
                          ['black', 'lightgray', 'red', 'green', 'cyan']):
    line = ax.plot(x, i, label=lab, linestyle=ls, lw=2, color=c)

ax.legend(loc = 'upper center', bbox_to_anchor = (0.5, 1.15), ncol = 5, fancybox = True, shadow = False)
ax.axhline(y = 0.5, linewidth = 1, color = 'k', linestyle = '--')
ax.axhline(y = 1.0, linewidth = 1, color = 'k', linestyle = '--')
plt.ylim([0, 1.1])
plt.xlabel('p(i=1)')
plt.ylabel('impurity index')
plt.show()

'지니 불순도'는 '엔트로피'와 '분류 오차'의 중간

 

 

트리 만들기

from sklearn.tree import DecisionTreeClassifier

tree_model = DecisionTreeClassifier(criterion = 'gini', 
                              max_depth = 4, 
                              random_state = 1)

tree_model.fit(X_train, y_train)
X_combined = np.vstack((X_train, X_test))
y_combined = np.hstack((y_train, y_test))

plot_decision_regions(X_combined, 
                      y_combined, 
                      classifier=tree_model, 
                      test_idx=range(105, 150))

plt.xlabel('petal length [standardized]')
plt.ylabel('petal width [standardized]')
plt.legend(loc = 'upper left')
plt.tight_layout()
plt.show()

결정 트리의 결정 경계

 

그림으로 의사결정나무 과정 확인

# sklearn에서 제공하는 plot_tree 기능 이용
from sklearn import tree
tree.plot_tree(tree_model)
plt.show()

plot_tree로 그린 결정 트리

 

# 그래프비즈와 파이닷을 이용해 그림 확인

from pydotplus import graph_from_dot_data
from sklearn.tree import export_graphviz

dot_data = export_graphviz(tree_model,
                           filled = True,
                           class_names = ['Setosa',
                                          'Versicolor',
                                          'Virginica'],
                           feature_names = ['petal length',
                                            'petal width'],
                           out_file = None)

graph = graph_from_dot_data(dot_data)
graph.write_png('tree_model.png')

graphviz와 pydotplus 이용

 

앙상블 - 랜덤 포레스트

  • 뛰어난 분류 성능과 과대적합에 안정적
  • 결정 트리의 앙상블
    • 여러 개의 결정 트리를 평균 내는 것
    • 개개의 트리는 분산이 높지만 앙상블을 통해 성능을 높이고 과대적합의 위험을 줄임
from sklearn.ensemble import RandomForestClassifier

forest = RandomForestClassifier(criterion = 'gini',
                                n_estimators = 25,
                                random_state = 1,
                                n_jobs = -1)
forest.fit(X_train, y_train)
plot_decision_regions(X_combined, 
                      y_combined, 
                      classifier=forest, 
                      test_idx=range(105, 150))
plt.xlabel('petal length [standardized]')
plt.ylabel('petal width [standardized]')
plt.legend(loc = 'upper left')
plt.tight_layout()
plt.show()

'AI' 카테고리의 다른 글

머신러닝 기초 - 읽어보기  (0) 2021.12.09
1208 - 서포트백터머신  (0) 2021.12.08
1207 - 로지스틱회귀  (0) 2021.12.07
1207 - 아달린 SGD  (0) 2021.12.07
1202 - 아달린  (0) 2021.12.02