一、決策樹簡介
決策樹(Decision Tree)是一種常見的分類和回歸演算法,其可處理離散型和連續型數據,在數據挖掘、機器學習等領域被廣泛應用。
決策樹的結構類似一棵樹,每個節點表示一個屬性,葉子節點表示一個類別(或回歸值),在決策時沿著樹結構前進,並根據節點所表示的屬性值進行選擇,直至到達葉子節點。
二、決策樹構建
決策樹的構建過程包括分裂屬性的選擇和樹的剪枝兩個主要環節。
屬性選擇
屬性選擇的目標是找到最優屬性,使其能夠按照屬性值將訓練集中的樣本劃分到正確的類別中,通常使用信息增益(Information Gain)或信息增益比(Information Gain Ratio)來選擇劃分屬性。
信息增益的計算公式如下:
def calc_information_gain(Y, X):
info_D = calc_entropy(Y) # 計算數據集的熵
m = X.shape[1] # 特徵數
info_Dv = np.zeros((m, 1))
for i in range(m):
# 計算按照第i個特徵劃分後的條件熵
Dv = split_dataset(X, Y, i)
info_Dv[i] = calc_cond_entropy(Y, Dv)
gain = info_D - info_Dv # 計算信息增益
return gain
其中calc_entropy(Y)計算數據集的熵,calc_cond_entropy(Y, Dv)計算按照某個特徵劃分後的條件熵,split_dataset(X, Y, i)將數據集按照第i個特徵劃分。
樹的剪枝
決策樹的剪枝是為了防止過擬合,採用預剪枝和後剪枝兩種方法。
預剪枝是在決策樹構建過程中,限制樹的大小來防止過擬合,例如限制樹的深度或葉子節點的最小樣本數。
後剪枝是在決策樹構建完成後,對決策樹進行剪枝來降低過擬合風險,常用的剪枝方法包括C4.5和CART演算法。
三、Python實現決策樹
數據集
我們以鳶尾花數據集為例,數據集包含150個樣本,每個樣本包含4個特徵,分別是花萼長度(sepal_length)、花萼寬度(sepal_width)、花瓣長度(petal_length)和花瓣寬度(petal_width),以及類別(setosa、versicolor和virginica)。
首先導入數據集:
import pandas as pd
df = pd.read_csv('iris.csv')
X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].to_numpy()
Y = df['species'].to_numpy()
將數據集劃分為訓練集和測試集:
from sklearn.model_selection import train_test_split X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=0)
決策樹演算法
首先定義決策樹節點的類(tree_node.py):
class TreeNode:
def __init__(self, feature_idx=None, threshold=None, mode=None, leaf=False):
self.feature_idx = feature_idx # 特徵索引
self.threshold = threshold # 特徵閾值
self.mode = mode # 類別(葉子節點才會有)
self.leaf = leaf # 是否為葉子節點
self.children = {} # 子節點列表
def split(self, X, Y, criterion='entropy'):
# 根據信息增益選擇最優特徵
if criterion == 'entropy':
gain = calc_information_gain(Y, X)
else: # criterion == 'gini'
gain = calc_gini_gain(Y, X)
i = np.argmax(gain)
# 按照最優特徵分裂數據集
Dv = split_dataset(X, Y, i)
# 如果信息增益為0,返回當前節點
if np.array_equal(gain, np.zeros((X.shape[1], 1))):
return self
# 創建子節點並遞歸構建子樹
for k, v in Dv.items():
node = TreeNode(leaf=(len(v) == 1))
if not node.leaf: # 非葉子節點
node = node.split(X[v], Y[v], criterion=criterion) # 遞歸構建子樹
else: # 葉子節點
node.mode = Y[v][0] # 葉子節點為樣本中最普遍的類別
self.children[k] = node
self.feature_idx = i
return self
定義決策樹的類(decision_tree.py):
class DecisionTree:
def __init__(self, criterion='entropy', max_depth=None, min_samples_leaf=1):
self.root = None # 決策樹根節點
self.criterion = criterion # 劃分標準:'entropy'或'gini'
self.max_depth = max_depth # 最大深度
self.min_samples_leaf = min_samples_leaf # 葉節點最小樣本數
def fit(self, X, Y):
self.root = TreeNode(leaf=(len(Y) == 1))
if not self.root.leaf: # 非葉子節點
self.root = self.root.split(X, Y, criterion=self.criterion) # 構建樹
else: # 葉子節點
self.root.mode = Y[0] # 葉子節點為樣本中最普遍的類別
def predict(self, X):
Y_pred = np.array([], dtype=int)
for x in X:
node = self.root
while not node.leaf:
i = node.feature_idx
if x[i] <= node.threshold:
node = node.children[0]
else:
node = node.children[1]
Y_pred = np.append(Y_pred, node.mode)
return Y_pred
應用決策樹
構建決策樹並對數據進行分類(main.py):
tree = DecisionTree(max_depth=3) tree.fit(X_train, Y_train) Y_pred = tree.predict(X_test)
計算分類精度,並繪製決策樹(plot_tree.py):
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
def plot_node(node_text, center_pt, parent_pt, node_type):
arrow_args = dict(arrowstyle=" max_depth:
max_depth = depth
return max_depth
def plot_tree(tree):
leaf_nodes_count = get_leafs_count(tree.root)
tree_depth = get_tree_depth(tree.root)
axprops = dict(xticks=[], yticks=[])
create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
plot_node(str(tree.root.mode), (0.5, 0.8), (0.5, 1.0), 'root')
plot_tree_node(tree.root, (0.5, 0.8), leaf_nodes_count, tree_depth)
plt.axis('off')
plt.show()
def plot_tree_node(tree_node, parent_center, leaf_nodes_count, tree_depth):
if tree_node.leaf: # 葉子節點
return
h_unit = 1.0 / tree_depth
v_unit = 1.0 / leaf_nodes_count
height = 0
for k, child in tree_node.children.items():
center = (parent_center[0] - v_unit, parent_center[1] - h_unit)
plot_arrow(center, parent_center, '<-')
if child.leaf: # 葉子節點
node_type = {'fc': '0.8', 'ec': 'black', 'boxstyle': 'round'}
node_text = str(child.mode)
else: # 非葉子節點
node_type = {'fc': '0.8', 'ec': 'black'}
node_text = 'X[{}] <= {:.2f}'.format(child.feature_idx, child.threshold)
plot_node(node_text, center, parent_center, node_type)
plot_tree_node(child, center, leaf_nodes_count, tree_depth)
height += 1
運行結果如下:
accuracy: 0.9778
四、小結
本文介紹了Python實現決策樹的方法,包括決策樹的構建過程、屬性選擇、樹的剪枝方法等,同時提供了完整的代碼實現。在實際應用中,決策樹演算法可以用於分類和回歸等場景,並且能夠處理離散型和連續型數據,是數據挖掘、機器學習等領域的重要演算法。
原創文章,作者:WGDBI,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/370587.html
微信掃一掃
支付寶掃一掃