用Python繪製混淆矩陣

一、Python繪製混淆矩陣

混淆矩陣是用來對分類器進行評價的一種矩陣,它展示了分類器在測試集上的預測效果。我們可以使用Python來畫出混淆矩陣。

首先需要明確的是,混淆矩陣的行表示實際類別,列表示預測類別。例如,對於一個二分類問題,其中一個類別為陽性,另一個類別為陰性,如果將陽性預測為陰性,則是一個False Negative(FN);如果將陰性預測為陽性,則是一個False Positive(FP)。同理,如果將陽性預測為陽性,則是一個True Positive(TP);如果將陰性預測為陰性,則是一個True Negative(TN)。

|         |   Predicted No    |   Predicted Yes   |
|---------|-------------------|-------------------|
|    No   |        TN         |        FP         |
|   Yes   |        FN         |        TP         |

我們可以使用Python中的sklearn庫來獲取分類器的混淆矩陣。下面是一個代碼示例:

from sklearn.metrics import confusion_matrix

y_true = [0, 1, 0, 1]
y_pred = [1, 1, 0, 0]

confusion_matrix(y_true, y_pred)

運行結果如下:

array([[1, 1],
       [1, 1]])

上方的結果表示,將0(No)預測為1(Yes)的次數為1,將1(Yes)預測為0(No)的次數為1,將0(No)預測為0(No)的次數為1,將1(Yes)預測為1(Yes)的次數為1。

二、混淆矩陣怎麼畫 Python

在了解了混淆矩陣的表示後,我們可以使用Python來畫出混淆矩陣的可視化圖。下方是一個代碼示例:

import numpy as np
import itertools
import matplotlib.pyplot as plt
 
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    將混淆矩陣畫出
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
 
    print(cm)
 
    # 畫圖
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
 
    # 循環遍歷混淆矩陣並填上數字
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
 
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()

# 測試數據
y_true = [0, 1, 0, 1]
y_pred = [1, 1, 0, 0]

# 獲取混淆矩陣
cnf_matrix = confusion_matrix(y_true, y_pred)

# 繪製混淆矩陣
plot_confusion_matrix(cnf_matrix, classes=[0, 1])

運行結果如下:

上方的圖表清晰地展示了混淆矩陣的各個元素以及數字標註。

三、Python混淆矩陣

要理解Python繪製混淆矩陣的原理,首先需要知道Python混淆矩陣的運作機制。混淆矩陣是指,將預測分類的結果與測試實際結果進行比較的矩陣。它是在機器學習數學領域中,用于衡量分類器(也稱模型)質量的矩陣。

在Python中,我們可以使用許多不同的Python庫來繪製混淆矩陣。常見的有matplotlib和seaborn等。這些庫提供了一些內置的函數,用於繪製各種矩陣和圖表。其中matplotlib.pyplot是Python中最常用的數據可視化庫之一,因此下文將使用它來完善混淆矩陣的可視化圖表。

四、Python混淆矩陣代碼

下面是一個Python繪製混淆矩陣的代碼示例:

import matplotlib.pyplot as plt
import numpy as np

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    繪製混淆矩陣
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    # 畫圖
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    # 設定橫縱坐標及標籤
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # 標註數字
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    # 設置其他參數
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.show()

# 測試數據
y_true = [0, 1, 0, 1]
y_pred = [1, 1, 0, 0]

# 獲取混淆矩陣
confusion_matrix = confusion_matrix(y_true, y_pred)

# 繪製混淆矩陣
class_names = ['0', '1'] # 類別名稱
plot_confusion_matrix(confusion_matrix, classes=class_names)

五、Python生成混淆矩陣

要生成Python混淆矩陣,我們需要預測的結果和實際的結果。如果我們善用該矩陣的可視化表示形式,我們就可以清晰地了解分類器的表現,並根據我們的真實數據來更新演算法。

下面是一個Python生成混淆矩陣的代碼示例:

from sklearn.metrics import confusion_matrix
import seaborn as sns

# 獲得測試數據
predictions = [0, 1, 1, 0, 1, 1]
true_classes = [1, 0, 1, 1, 1, 0]

# 獲取混淆矩陣並利用熱力圖展示
cm = confusion_matrix(true_classes, predictions)
sns.heatmap(cm, annot=True, fmt='d', cmap='Reds')

六、Python畫混淆矩陣

現在我們已經掌握了Python生成和繪製混淆矩陣的基本原理,下面這個代碼示例演示了如何用Python畫混淆矩陣:

from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import pandas as pd

# 定義畫圖函數
def plot_confusion_matrix(y_true, y_pred, cmap="Blues"):
    
    # 計算混淆矩陣數據
    cm = confusion_matrix(y_true, y_pred)
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    cm_perc = cm / cm_sum.astype(float) * 100
    annot = np.empty_like(cm).astype(str)
    nrows, ncols = cm.shape
    for i in range(nrows):
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                s = cm_sum[i]
                annot[i, j] = '%.1f%%\n%d/%d' % (p, c, s)
            elif c == 0:
                annot[i, j] = ''
            else:
                annot[i, j] = '%.1f%%\n%d' % (p, c)
    cm = pd.DataFrame(cm, index=['True A', 'True B'], columns=['Pred A', 'Pred B'])
    cm.index.name = 'Actual'
    cm.columns.name = 'Predicted'
 
    # 設置畫布
    fig, ax = plt.subplots(figsize=(2.5, 2.5))
    ax.text(-1.2, 1.2, 'Confusion\nMatrix', fontsize=12, transform=ax.transAxes)

    # 設置其它參數
    sns.heatmap(cm, annot=annot, fmt='', cmap=cmap, ax=ax)

    plt.show()

# 測試數據
actuals = ['A', 'B', 'B', 'A', 'A', 'B']
predicted = ['A', 'B', 'B', 'A', 'A', 'A']

# 畫混淆矩陣
plot_confusion_matrix(actuals, predicted)

上方的混淆矩陣清晰且易於理解,對於評估分類結果具有重要意義。

原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/294109.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
小藍的頭像小藍
上一篇 2024-12-26 13:15
下一篇 2024-12-26 13:15

相關推薦

  • Python列表中負數的個數

    Python列表是一個有序的集合,可以存儲多個不同類型的元素。而負數是指小於0的整數。在Python列表中,我們想要找到負數的個數,可以通過以下幾個方面進行實現。 一、使用循環遍歷…

    編程 2025-04-29
  • Python周杰倫代碼用法介紹

    本文將從多個方面對Python周杰倫代碼進行詳細的闡述。 一、代碼介紹 from urllib.request import urlopen from bs4 import Bea…

    編程 2025-04-29
  • 如何查看Anaconda中Python路徑

    對Anaconda中Python路徑即conda環境的查看進行詳細的闡述。 一、使用命令行查看 1、在Windows系統中,可以使用命令提示符(cmd)或者Anaconda Pro…

    編程 2025-04-29
  • Python計算陽曆日期對應周幾

    本文介紹如何通過Python計算任意陽曆日期對應周幾。 一、獲取日期 獲取日期可以通過Python內置的模塊datetime實現,示例代碼如下: from datetime imp…

    編程 2025-04-29
  • Python中引入上一級目錄中函數

    Python中經常需要調用其他文件夾中的模塊或函數,其中一個常見的操作是引入上一級目錄中的函數。在此,我們將從多個角度詳細解釋如何在Python中引入上一級目錄的函數。 一、加入環…

    編程 2025-04-29
  • Python字典去重複工具

    使用Python語言編寫字典去重複工具,可幫助用戶快速去重複。 一、字典去重複工具的需求 在使用Python編寫程序時,我們經常需要處理數據文件,其中包含了大量的重複數據。為了方便…

    編程 2025-04-29
  • python強行終止程序快捷鍵

    本文將從多個方面對python強行終止程序快捷鍵進行詳細闡述,並提供相應代碼示例。 一、Ctrl+C快捷鍵 Ctrl+C快捷鍵是在終端中經常用來強行終止運行的程序。當你在終端中運行…

    編程 2025-04-29
  • 蝴蝶優化演算法Python版

    蝴蝶優化演算法是一種基於仿生學的優化演算法,模仿自然界中的蝴蝶進行搜索。它可以應用於多個領域的優化問題,包括數學優化、工程問題、機器學習等。本文將從多個方面對蝴蝶優化演算法Python版…

    編程 2025-04-29
  • Python清華鏡像下載

    Python清華鏡像是一個高質量的Python開發資源鏡像站,提供了Python及其相關的開發工具、框架和文檔的下載服務。本文將從以下幾個方面對Python清華鏡像下載進行詳細的闡…

    編程 2025-04-29
  • Python程序需要編譯才能執行

    Python 被廣泛應用於數據分析、人工智慧、科學計算等領域,它的靈活性和簡單易學的性質使得越來越多的人喜歡使用 Python 進行編程。然而,在 Python 中程序執行的方式不…

    編程 2025-04-29

發表回復

登錄後才能評論