Python實現決策樹

一、決策樹簡介

決策樹(Decision Tree)是一種常見的分類和回歸演算法,其可處理離散型和連續型數據,在數據挖掘、機器學習等領域被廣泛應用。

決策樹的結構類似一棵樹,每個節點表示一個屬性,葉子節點表示一個類別(或回歸值),在決策時沿著樹結構前進,並根據節點所表示的屬性值進行選擇,直至到達葉子節點。

二、決策樹構建

決策樹的構建過程包括分裂屬性的選擇和樹的剪枝兩個主要環節。

屬性選擇

屬性選擇的目標是找到最優屬性,使其能夠按照屬性值將訓練集中的樣本劃分到正確的類別中,通常使用信息增益(Information Gain)或信息增益比(Information Gain Ratio)來選擇劃分屬性。

信息增益的計算公式如下:

def calc_information_gain(Y, X):
    info_D = calc_entropy(Y)  # 計算數據集的熵
    m = X.shape[1]  # 特徵數
    info_Dv = np.zeros((m, 1))
    for i in range(m):
        # 計算按照第i個特徵劃分後的條件熵
        Dv = split_dataset(X, Y, i)
        info_Dv[i] = calc_cond_entropy(Y, Dv)
    gain = info_D - info_Dv  # 計算信息增益
    return gain

其中calc_entropy(Y)計算數據集的熵,calc_cond_entropy(Y, Dv)計算按照某個特徵劃分後的條件熵,split_dataset(X, Y, i)將數據集按照第i個特徵劃分。

樹的剪枝

決策樹的剪枝是為了防止過擬合,採用預剪枝和後剪枝兩種方法。

預剪枝是在決策樹構建過程中,限制樹的大小來防止過擬合,例如限制樹的深度或葉子節點的最小樣本數。

後剪枝是在決策樹構建完成後,對決策樹進行剪枝來降低過擬合風險,常用的剪枝方法包括C4.5和CART演算法。

三、Python實現決策樹

數據集

我們以鳶尾花數據集為例,數據集包含150個樣本,每個樣本包含4個特徵,分別是花萼長度(sepal_length)、花萼寬度(sepal_width)、花瓣長度(petal_length)和花瓣寬度(petal_width),以及類別(setosa、versicolor和virginica)。

首先導入數據集:

import pandas as pd

df = pd.read_csv('iris.csv')
X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].to_numpy()
Y = df['species'].to_numpy()

將數據集劃分為訓練集和測試集:

from sklearn.model_selection import train_test_split

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=0)

決策樹演算法

首先定義決策樹節點的類(tree_node.py):

class TreeNode:
    def __init__(self, feature_idx=None, threshold=None, mode=None, leaf=False):
        self.feature_idx = feature_idx  # 特徵索引
        self.threshold = threshold  # 特徵閾值
        self.mode = mode  # 類別(葉子節點才會有)
        self.leaf = leaf  # 是否為葉子節點
        self.children = {}  # 子節點列表

    def split(self, X, Y, criterion='entropy'):
        # 根據信息增益選擇最優特徵
        if criterion == 'entropy':
            gain = calc_information_gain(Y, X)
        else:  # criterion == 'gini'
            gain = calc_gini_gain(Y, X)
        i = np.argmax(gain)
        # 按照最優特徵分裂數據集
        Dv = split_dataset(X, Y, i)
        # 如果信息增益為0,返回當前節點
        if np.array_equal(gain, np.zeros((X.shape[1], 1))):
            return self
        # 創建子節點並遞歸構建子樹
        for k, v in Dv.items():
            node = TreeNode(leaf=(len(v) == 1))
            if not node.leaf:  # 非葉子節點
                node = node.split(X[v], Y[v], criterion=criterion)  # 遞歸構建子樹
            else:  # 葉子節點
                node.mode = Y[v][0]  # 葉子節點為樣本中最普遍的類別
            self.children[k] = node
            self.feature_idx = i
        return self

定義決策樹的類(decision_tree.py):

class DecisionTree:
    def __init__(self, criterion='entropy', max_depth=None, min_samples_leaf=1):
        self.root = None  # 決策樹根節點
        self.criterion = criterion  # 劃分標準:'entropy'或'gini'
        self.max_depth = max_depth  # 最大深度
        self.min_samples_leaf = min_samples_leaf  # 葉節點最小樣本數

    def fit(self, X, Y):
        self.root = TreeNode(leaf=(len(Y) == 1))
        if not self.root.leaf:  # 非葉子節點
            self.root = self.root.split(X, Y, criterion=self.criterion)  # 構建樹
        else:  # 葉子節點
            self.root.mode = Y[0]  # 葉子節點為樣本中最普遍的類別

    def predict(self, X):
        Y_pred = np.array([], dtype=int)
        for x in X:
            node = self.root
            while not node.leaf:
                i = node.feature_idx
                if x[i] <= node.threshold:
                    node = node.children[0]
                else:
                    node = node.children[1]
            Y_pred = np.append(Y_pred, node.mode)
        return Y_pred

應用決策樹

構建決策樹並對數據進行分類(main.py):

tree = DecisionTree(max_depth=3)
tree.fit(X_train, Y_train)
Y_pred = tree.predict(X_test)

計算分類精度,並繪製決策樹(plot_tree.py):

from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch

def plot_node(node_text, center_pt, parent_pt, node_type):
    arrow_args = dict(arrowstyle=" max_depth:
            max_depth = depth
    return max_depth

def plot_tree(tree):
    leaf_nodes_count = get_leafs_count(tree.root)
    tree_depth = get_tree_depth(tree.root)
    axprops = dict(xticks=[], yticks=[])
    create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plot_node(str(tree.root.mode), (0.5, 0.8), (0.5, 1.0), 'root')
    plot_tree_node(tree.root, (0.5, 0.8), leaf_nodes_count, tree_depth)
    plt.axis('off')
    plt.show()

def plot_tree_node(tree_node, parent_center, leaf_nodes_count, tree_depth):
    if tree_node.leaf:  # 葉子節點
        return
    h_unit = 1.0 / tree_depth
    v_unit = 1.0 / leaf_nodes_count
    height = 0
    for k, child in tree_node.children.items():
        center = (parent_center[0] - v_unit, parent_center[1] - h_unit)
        plot_arrow(center, parent_center, '<-')
        if child.leaf:  # 葉子節點
            node_type = {'fc': '0.8', 'ec': 'black', 'boxstyle': 'round'}
            node_text = str(child.mode)
        else:  # 非葉子節點
            node_type = {'fc': '0.8', 'ec': 'black'}
            node_text = 'X[{}] <= {:.2f}'.format(child.feature_idx, child.threshold)
        plot_node(node_text, center, parent_center, node_type)
        plot_tree_node(child, center, leaf_nodes_count, tree_depth)
        height += 1

運行結果如下:

accuracy: 0.9778

四、小結

本文介紹了Python實現決策樹的方法,包括決策樹的構建過程、屬性選擇、樹的剪枝方法等,同時提供了完整的代碼實現。在實際應用中,決策樹演算法可以用於分類和回歸等場景,並且能夠處理離散型和連續型數據,是數據挖掘、機器學習等領域的重要演算法。

原創文章,作者:WGDBI,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/370587.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
WGDBI的頭像WGDBI
上一篇 2025-04-22 01:14
下一篇 2025-04-22 01:14

相關推薦

  • Python列表中負數的個數

    Python列表是一個有序的集合,可以存儲多個不同類型的元素。而負數是指小於0的整數。在Python列表中,我們想要找到負數的個數,可以通過以下幾個方面進行實現。 一、使用循環遍歷…

    編程 2025-04-29
  • 如何查看Anaconda中Python路徑

    對Anaconda中Python路徑即conda環境的查看進行詳細的闡述。 一、使用命令行查看 1、在Windows系統中,可以使用命令提示符(cmd)或者Anaconda Pro…

    編程 2025-04-29
  • Python中引入上一級目錄中函數

    Python中經常需要調用其他文件夾中的模塊或函數,其中一個常見的操作是引入上一級目錄中的函數。在此,我們將從多個角度詳細解釋如何在Python中引入上一級目錄的函數。 一、加入環…

    編程 2025-04-29
  • Python計算陽曆日期對應周幾

    本文介紹如何通過Python計算任意陽曆日期對應周幾。 一、獲取日期 獲取日期可以通過Python內置的模塊datetime實現,示例代碼如下: from datetime imp…

    編程 2025-04-29
  • Python周杰倫代碼用法介紹

    本文將從多個方面對Python周杰倫代碼進行詳細的闡述。 一、代碼介紹 from urllib.request import urlopen from bs4 import Bea…

    編程 2025-04-29
  • Python字典去重複工具

    使用Python語言編寫字典去重複工具,可幫助用戶快速去重複。 一、字典去重複工具的需求 在使用Python編寫程序時,我們經常需要處理數據文件,其中包含了大量的重複數據。為了方便…

    編程 2025-04-29
  • Python清華鏡像下載

    Python清華鏡像是一個高質量的Python開發資源鏡像站,提供了Python及其相關的開發工具、框架和文檔的下載服務。本文將從以下幾個方面對Python清華鏡像下載進行詳細的闡…

    編程 2025-04-29
  • python強行終止程序快捷鍵

    本文將從多個方面對python強行終止程序快捷鍵進行詳細闡述,並提供相應代碼示例。 一、Ctrl+C快捷鍵 Ctrl+C快捷鍵是在終端中經常用來強行終止運行的程序。當你在終端中運行…

    編程 2025-04-29
  • 蝴蝶優化演算法Python版

    蝴蝶優化演算法是一種基於仿生學的優化演算法,模仿自然界中的蝴蝶進行搜索。它可以應用於多個領域的優化問題,包括數學優化、工程問題、機器學習等。本文將從多個方面對蝴蝶優化演算法Python版…

    編程 2025-04-29
  • Python程序需要編譯才能執行

    Python 被廣泛應用於數據分析、人工智慧、科學計算等領域,它的靈活性和簡單易學的性質使得越來越多的人喜歡使用 Python 進行編程。然而,在 Python 中程序執行的方式不…

    編程 2025-04-29

發表回復

登錄後才能評論