一、决策树简介
决策树(Decision Tree)是一种常见的分类和回归算法,其可处理离散型和连续型数据,在数据挖掘、机器学习等领域被广泛应用。
决策树的结构类似一棵树,每个节点表示一个属性,叶子节点表示一个类别(或回归值),在决策时沿着树结构前进,并根据节点所表示的属性值进行选择,直至到达叶子节点。
二、决策树构建
决策树的构建过程包括分裂属性的选择和树的剪枝两个主要环节。
属性选择
属性选择的目标是找到最优属性,使其能够按照属性值将训练集中的样本划分到正确的类别中,通常使用信息增益(Information Gain)或信息增益比(Information Gain Ratio)来选择划分属性。
信息增益的计算公式如下:
def calc_information_gain(Y, X):
info_D = calc_entropy(Y) # 计算数据集的熵
m = X.shape[1] # 特征数
info_Dv = np.zeros((m, 1))
for i in range(m):
# 计算按照第i个特征划分后的条件熵
Dv = split_dataset(X, Y, i)
info_Dv[i] = calc_cond_entropy(Y, Dv)
gain = info_D - info_Dv # 计算信息增益
return gain
其中calc_entropy(Y)计算数据集的熵,calc_cond_entropy(Y, Dv)计算按照某个特征划分后的条件熵,split_dataset(X, Y, i)将数据集按照第i个特征划分。
树的剪枝
决策树的剪枝是为了防止过拟合,采用预剪枝和后剪枝两种方法。
预剪枝是在决策树构建过程中,限制树的大小来防止过拟合,例如限制树的深度或叶子节点的最小样本数。
后剪枝是在决策树构建完成后,对决策树进行剪枝来降低过拟合风险,常用的剪枝方法包括C4.5和CART算法。
三、Python实现决策树
数据集
我们以鸢尾花数据集为例,数据集包含150个样本,每个样本包含4个特征,分别是花萼长度(sepal_length)、花萼宽度(sepal_width)、花瓣长度(petal_length)和花瓣宽度(petal_width),以及类别(setosa、versicolor和virginica)。
首先导入数据集:
import pandas as pd
df = pd.read_csv('iris.csv')
X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].to_numpy()
Y = df['species'].to_numpy()
将数据集划分为训练集和测试集:
from sklearn.model_selection import train_test_split X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=0)
决策树算法
首先定义决策树节点的类(tree_node.py):
class TreeNode:
def __init__(self, feature_idx=None, threshold=None, mode=None, leaf=False):
self.feature_idx = feature_idx # 特征索引
self.threshold = threshold # 特征阈值
self.mode = mode # 类别(叶子节点才会有)
self.leaf = leaf # 是否为叶子节点
self.children = {} # 子节点列表
def split(self, X, Y, criterion='entropy'):
# 根据信息增益选择最优特征
if criterion == 'entropy':
gain = calc_information_gain(Y, X)
else: # criterion == 'gini'
gain = calc_gini_gain(Y, X)
i = np.argmax(gain)
# 按照最优特征分裂数据集
Dv = split_dataset(X, Y, i)
# 如果信息增益为0,返回当前节点
if np.array_equal(gain, np.zeros((X.shape[1], 1))):
return self
# 创建子节点并递归构建子树
for k, v in Dv.items():
node = TreeNode(leaf=(len(v) == 1))
if not node.leaf: # 非叶子节点
node = node.split(X[v], Y[v], criterion=criterion) # 递归构建子树
else: # 叶子节点
node.mode = Y[v][0] # 叶子节点为样本中最普遍的类别
self.children[k] = node
self.feature_idx = i
return self
定义决策树的类(decision_tree.py):
class DecisionTree:
def __init__(self, criterion='entropy', max_depth=None, min_samples_leaf=1):
self.root = None # 决策树根节点
self.criterion = criterion # 划分标准:'entropy'或'gini'
self.max_depth = max_depth # 最大深度
self.min_samples_leaf = min_samples_leaf # 叶节点最小样本数
def fit(self, X, Y):
self.root = TreeNode(leaf=(len(Y) == 1))
if not self.root.leaf: # 非叶子节点
self.root = self.root.split(X, Y, criterion=self.criterion) # 构建树
else: # 叶子节点
self.root.mode = Y[0] # 叶子节点为样本中最普遍的类别
def predict(self, X):
Y_pred = np.array([], dtype=int)
for x in X:
node = self.root
while not node.leaf:
i = node.feature_idx
if x[i] <= node.threshold:
node = node.children[0]
else:
node = node.children[1]
Y_pred = np.append(Y_pred, node.mode)
return Y_pred
应用决策树
构建决策树并对数据进行分类(main.py):
tree = DecisionTree(max_depth=3) tree.fit(X_train, Y_train) Y_pred = tree.predict(X_test)
计算分类精度,并绘制决策树(plot_tree.py):
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
def plot_node(node_text, center_pt, parent_pt, node_type):
arrow_args = dict(arrowstyle=" max_depth:
max_depth = depth
return max_depth
def plot_tree(tree):
leaf_nodes_count = get_leafs_count(tree.root)
tree_depth = get_tree_depth(tree.root)
axprops = dict(xticks=[], yticks=[])
create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
plot_node(str(tree.root.mode), (0.5, 0.8), (0.5, 1.0), 'root')
plot_tree_node(tree.root, (0.5, 0.8), leaf_nodes_count, tree_depth)
plt.axis('off')
plt.show()
def plot_tree_node(tree_node, parent_center, leaf_nodes_count, tree_depth):
if tree_node.leaf: # 叶子节点
return
h_unit = 1.0 / tree_depth
v_unit = 1.0 / leaf_nodes_count
height = 0
for k, child in tree_node.children.items():
center = (parent_center[0] - v_unit, parent_center[1] - h_unit)
plot_arrow(center, parent_center, '<-')
if child.leaf: # 叶子节点
node_type = {'fc': '0.8', 'ec': 'black', 'boxstyle': 'round'}
node_text = str(child.mode)
else: # 非叶子节点
node_type = {'fc': '0.8', 'ec': 'black'}
node_text = 'X[{}] <= {:.2f}'.format(child.feature_idx, child.threshold)
plot_node(node_text, center, parent_center, node_type)
plot_tree_node(child, center, leaf_nodes_count, tree_depth)
height += 1
运行结果如下:
accuracy: 0.9778
四、小结
本文介绍了Python实现决策树的方法,包括决策树的构建过程、属性选择、树的剪枝方法等,同时提供了完整的代码实现。在实际应用中,决策树算法可以用于分类和回归等场景,并且能够处理离散型和连续型数据,是数据挖掘、机器学习等领域的重要算法。
原创文章,作者:WGDBI,如若转载,请注明出处:https://www.506064.com/n/370587.html
微信扫一扫
支付宝扫一扫