線性判別分析(Linear Discriminant Analysis)

一、算法概述

線性判別分析(Linear Discriminant Analysis,LDA)是常用的一種分類算法。它是一種有監督學習方法,也就是需要已知每個樣本的類別標籤。LDA從特徵空間中提取線性判別信息,用於在低維空間中對數據進行分類。

LDA的主要思想是:將樣本投影到一條直線上,使得同類之間的距離儘可能小,不同類之間的距離儘可能大。因此,LDA在降維同時也完成了分類任務。

二、算法流程

LDA的算法流程如下:

  1. 計算每個類別的均值向量。
  2. 計算類間散度矩陣和類內散度矩陣。
  3. 求出最大化目標函數的投影方向。
  4. 降維並進行分類。

三、算法實現

第一步:計算均值向量

設有樣本集$D={x_1, x_2, …, x_n}$,其中$x_i\in R^d$表示第$i$個樣本。將樣本按照類別分開,設有$k$個類別。第$i$類樣本的個數為$n_i$,均值向量為$u_i$。

import numpy as np

def mean_vectors(X, y):
    class_labels = np.unique(y)
    n_classes = class_labels.shape[0]
    mean_vectors = np.zeros((n_classes, X.shape[1]))
    for cl, label in enumerate(class_labels):
        mean_vectors[cl,:] = np.mean(X[y==label], axis=0)
    return mean_vectors

第二步:計算類間散度矩陣和類內散度矩陣

類間散度矩陣$S_B$和類內散度矩陣$S_W$的計算方式如下:

$$S_B = \sum_{i=1}^{k}n_i(u_i – u)(u_i – u)^T$$
$$S_W = \sum_{i=1}^{k}\sum_{x\in D_i}(x-u_i)(x-u_i)^T$$
其中,$u$為所有樣本的均值向量,$S_B$表示類別之間的差異,$S_W$表示類別內部的差異。

def scatter_matrices(X, y):
    mean_vectors = mean_vectors(X, y)
    n_features = X.shape[1]
    s_within = np.zeros((n_features, n_features))
    s_between = np.zeros((n_features, n_features))
    mean_overall = np.mean(X, axis=0)
    for cl, mv in enumerate(mean_vectors):
        class_sc_mat = np.zeros((n_features, n_features))
        for row in X[y == cl]:
            row, mv = row.reshape(n_features, 1), mv.reshape(n_features, 1)
            class_sc_mat += (row - mv).dot((row - mv).T)
        s_within += class_sc_mat        
        n = X[y==cl,:].shape[0]
        mean_diff = (mv - mean_overall).reshape(n_features, 1)
        s_between += n * mean_diff.dot(mean_diff.T)
    return s_within, s_between

第三步:求投影方向

為了求出最大化目標函數的投影方向,需要計算矩陣$S_W^{-1}S_B$的特徵向量和特徵值。在計算投影矩陣的時候,我們可以取最大的$d’$個特徵向量($d’$表示投影后保留的維數)。

def lda(X, y, n_components):
    s_within, s_between = scatter_matrices(X, y)
    eig_vals, eig_vecs = np.linalg.eig(np.linalg.inv(s_within).dot(s_between))
    eig_pairs = [(np.abs(eig_vals[i]), eig_vecs[:,i]) for i in range(len(eig_vals))]
    eig_pairs = sorted(eig_pairs, key=lambda k: k[0], reverse=True)
    proj_mat = eig_pairs[0][1].reshape(X.shape[1],1)
    for i in range(1,n_components):
        proj_mat = np.hstack((proj_mat, eig_pairs[i][1].reshape(X.shape[1],1)))
    return X.dot(proj_mat), proj_mat

第四步:降維並進行分類

使用投影矩陣將樣本投影到新的低維空間中,並使用分類算法進行分類。下面的例子中,我們使用支持向量機作為分類器。

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

iris = datasets.load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

X_train_lda, proj_mat = lda(X_train, y_train, 2)
X_test_lda = X_test.dot(proj_mat)

classifier = SVC()
classifier.fit(X_train_lda, y_train)
score = classifier.score(X_test_lda, y_test)
print("Accuracy:", score)

四、總結

LDA作為一種常用的分類算法,在特徵提取和降維方面有着廣泛的應用。通過計算均值向量和散度矩陣,我們可以求出最大化目標函數的投影方向,從而實現對數據的降維和分類。

原創文章,作者:RTIFU,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/316738.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
RTIFU的頭像RTIFU
上一篇 2025-01-09 12:14
下一篇 2025-01-09 12:14

相關推薦

  • Python實現一元線性回歸模型

    本文將從多個方面詳細闡述Python實現一元線性回歸模型的代碼。如果你對線性回歸模型有一些了解,對Python語言也有所掌握,那麼本文將對你有所幫助。在開始介紹具體代碼前,讓我們先…

    編程 2025-04-29
  • Python線性插值法:用數學建模實現精確預測

    本文將會詳細介紹Python線性插值法的實現方式和應用場景。 一、插值法概述 插值法是基於已知數據點得出缺失數據點的一種方法。它常用於科學計算中的函數逼近,是一種基礎的數學建模技術…

    編程 2025-04-27
  • 簡單線性回歸

    一、什麼是簡單線性回歸 簡單線性回歸是一種基本的統計方法,用於描述兩個變量之間的關係。其中一個變量是自變量(解釋變量),另一個變量是因變量(響應變量)。 簡單線性回歸通常用於預測。…

    編程 2025-02-25
  • 線性回歸數據集的實踐與探究

    一、數據集介紹 線性回歸數據集是機器學習中最基礎的數據集之一,通常包含訓練集和測試集。在這個數據集中,我們希望通過輸入不同的特徵值來預測輸出的目標值。 例如,一個房屋價格預測的線性…

    編程 2025-02-24
  • 如何解讀r方很低但是p值顯著的線性回歸

    一、線性回歸r方很低但是p值顯著 在進行線性回歸分析的時候,我們通常會關注兩個指標:r方和p值。r方是反映自變量對因變量的解釋力度,範圍在0~1之間,而p值則是反映自變量與因變量之…

    編程 2025-02-24
  • 深入理解PyTorch中的線性層

    一、線性層簡介 在深度學習中,線性層是最基本的模型之一。PyTorch作為流行的深度學習框架,也提供了很好的線性層構建機制。 線性層(Linear Layer),也稱為全連接層(F…

    編程 2025-02-05
  • 使用Python進行線性回歸預測房價

    一、了解線性回歸 線性回歸是一種用於建立變量之間關係的方法,通常用於預測連續型變量。它假設不同變量之間存在一種線性關係,即每個自變量對因變量的影響是相加的。 在房價預測中,我們可以…

    編程 2025-02-05
  • 線性篩素數詳解

    一、簡介 線性篩素數,顧名思義,是一種用線性時間複雜度求出所有素數的方法。相比於其他素數篩法,線性篩素數更加高效,因此在實際應用中經常被使用。 二、原理 線性篩素數的核心思想是將每…

    編程 2025-01-20
  • 廣義線性混合模型

    一、簡介 廣義線性混合模型是一種統計模型,在許多實際問題中都有廣泛的應用。該模型的主要特點是可以同時處理連續型變量、二元型變量、計數型變量以及其他類型的變量。同時,廣義線性混合模型…

    編程 2025-01-14
  • 用python編寫線性回歸程序,4python簡單線性回歸代碼案例完整

    本文目錄一覽: 1、關於python簡單線性回歸 2、用python寫一個小程序,輸入坐標求線性回歸 3、python線性回歸有哪些方法 4、python怎麼用線性回歸擬合 5、如…

    編程 2025-01-09

發表回復

登錄後才能評論