一、決策樹簡介
決策樹(Decision Tree)是一種常見的分類和回歸演算法,其可處理離散型和連續型數據,在數據挖掘、機器學習等領域被廣泛應用。
決策樹的結構類似一棵樹,每個節點表示一個屬性,葉子節點表示一個類別(或回歸值),在決策時沿著樹結構前進,並根據節點所表示的屬性值進行選擇,直至到達葉子節點。
二、決策樹構建
決策樹的構建過程包括分裂屬性的選擇和樹的剪枝兩個主要環節。
屬性選擇
屬性選擇的目標是找到最優屬性,使其能夠按照屬性值將訓練集中的樣本劃分到正確的類別中,通常使用信息增益(Information Gain)或信息增益比(Information Gain Ratio)來選擇劃分屬性。
信息增益的計算公式如下:
def calc_information_gain(Y, X): info_D = calc_entropy(Y) # 計算數據集的熵 m = X.shape[1] # 特徵數 info_Dv = np.zeros((m, 1)) for i in range(m): # 計算按照第i個特徵劃分後的條件熵 Dv = split_dataset(X, Y, i) info_Dv[i] = calc_cond_entropy(Y, Dv) gain = info_D - info_Dv # 計算信息增益 return gain
其中calc_entropy(Y)計算數據集的熵,calc_cond_entropy(Y, Dv)計算按照某個特徵劃分後的條件熵,split_dataset(X, Y, i)將數據集按照第i個特徵劃分。
樹的剪枝
決策樹的剪枝是為了防止過擬合,採用預剪枝和後剪枝兩種方法。
預剪枝是在決策樹構建過程中,限制樹的大小來防止過擬合,例如限制樹的深度或葉子節點的最小樣本數。
後剪枝是在決策樹構建完成後,對決策樹進行剪枝來降低過擬合風險,常用的剪枝方法包括C4.5和CART演算法。
三、Python實現決策樹
數據集
我們以鳶尾花數據集為例,數據集包含150個樣本,每個樣本包含4個特徵,分別是花萼長度(sepal_length)、花萼寬度(sepal_width)、花瓣長度(petal_length)和花瓣寬度(petal_width),以及類別(setosa、versicolor和virginica)。
首先導入數據集:
import pandas as pd df = pd.read_csv('iris.csv') X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].to_numpy() Y = df['species'].to_numpy()
將數據集劃分為訓練集和測試集:
from sklearn.model_selection import train_test_split X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=0)
決策樹演算法
首先定義決策樹節點的類(tree_node.py):
class TreeNode: def __init__(self, feature_idx=None, threshold=None, mode=None, leaf=False): self.feature_idx = feature_idx # 特徵索引 self.threshold = threshold # 特徵閾值 self.mode = mode # 類別(葉子節點才會有) self.leaf = leaf # 是否為葉子節點 self.children = {} # 子節點列表 def split(self, X, Y, criterion='entropy'): # 根據信息增益選擇最優特徵 if criterion == 'entropy': gain = calc_information_gain(Y, X) else: # criterion == 'gini' gain = calc_gini_gain(Y, X) i = np.argmax(gain) # 按照最優特徵分裂數據集 Dv = split_dataset(X, Y, i) # 如果信息增益為0,返回當前節點 if np.array_equal(gain, np.zeros((X.shape[1], 1))): return self # 創建子節點並遞歸構建子樹 for k, v in Dv.items(): node = TreeNode(leaf=(len(v) == 1)) if not node.leaf: # 非葉子節點 node = node.split(X[v], Y[v], criterion=criterion) # 遞歸構建子樹 else: # 葉子節點 node.mode = Y[v][0] # 葉子節點為樣本中最普遍的類別 self.children[k] = node self.feature_idx = i return self
定義決策樹的類(decision_tree.py):
class DecisionTree: def __init__(self, criterion='entropy', max_depth=None, min_samples_leaf=1): self.root = None # 決策樹根節點 self.criterion = criterion # 劃分標準:'entropy'或'gini' self.max_depth = max_depth # 最大深度 self.min_samples_leaf = min_samples_leaf # 葉節點最小樣本數 def fit(self, X, Y): self.root = TreeNode(leaf=(len(Y) == 1)) if not self.root.leaf: # 非葉子節點 self.root = self.root.split(X, Y, criterion=self.criterion) # 構建樹 else: # 葉子節點 self.root.mode = Y[0] # 葉子節點為樣本中最普遍的類別 def predict(self, X): Y_pred = np.array([], dtype=int) for x in X: node = self.root while not node.leaf: i = node.feature_idx if x[i] <= node.threshold: node = node.children[0] else: node = node.children[1] Y_pred = np.append(Y_pred, node.mode) return Y_pred
應用決策樹
構建決策樹並對數據進行分類(main.py):
tree = DecisionTree(max_depth=3) tree.fit(X_train, Y_train) Y_pred = tree.predict(X_test)
計算分類精度,並繪製決策樹(plot_tree.py):
from sklearn.metrics import accuracy_score import matplotlib.pyplot as plt from matplotlib.patches import FancyArrowPatch def plot_node(node_text, center_pt, parent_pt, node_type): arrow_args = dict(arrowstyle=" max_depth: max_depth = depth return max_depth def plot_tree(tree): leaf_nodes_count = get_leafs_count(tree.root) tree_depth = get_tree_depth(tree.root) axprops = dict(xticks=[], yticks=[]) create_plot.ax1 = plt.subplot(111, frameon=False, **axprops) plot_node(str(tree.root.mode), (0.5, 0.8), (0.5, 1.0), 'root') plot_tree_node(tree.root, (0.5, 0.8), leaf_nodes_count, tree_depth) plt.axis('off') plt.show() def plot_tree_node(tree_node, parent_center, leaf_nodes_count, tree_depth): if tree_node.leaf: # 葉子節點 return h_unit = 1.0 / tree_depth v_unit = 1.0 / leaf_nodes_count height = 0 for k, child in tree_node.children.items(): center = (parent_center[0] - v_unit, parent_center[1] - h_unit) plot_arrow(center, parent_center, '<-') if child.leaf: # 葉子節點 node_type = {'fc': '0.8', 'ec': 'black', 'boxstyle': 'round'} node_text = str(child.mode) else: # 非葉子節點 node_type = {'fc': '0.8', 'ec': 'black'} node_text = 'X[{}] <= {:.2f}'.format(child.feature_idx, child.threshold) plot_node(node_text, center, parent_center, node_type) plot_tree_node(child, center, leaf_nodes_count, tree_depth) height += 1
運行結果如下:
accuracy: 0.9778
四、小結
本文介紹了Python實現決策樹的方法,包括決策樹的構建過程、屬性選擇、樹的剪枝方法等,同時提供了完整的代碼實現。在實際應用中,決策樹演算法可以用於分類和回歸等場景,並且能夠處理離散型和連續型數據,是數據挖掘、機器學習等領域的重要演算法。
原創文章,作者:WGDBI,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/370587.html