一、Gumbel Softmax簡介
Gumbel Softmax是一種基於採樣的概率分布生成演算法,它用於從一個具有固定參數的分布中生成一組概率分布。 具體地說,它可以通過使用伯努利分布對樣本進行採樣來生成一個概率分布序列。該演算法的應用包括生成離散變數的序列和特權探測機制等。
通俗點解釋Gumbel分布就是從兩個獨立的均勻分布變數中減去log(-log(U))的值的和,其中U是從均勻分布中隨機採樣的。Gumbel Softmax隨機向量的生成操作包括兩個步驟:
1、從一個Gumbel(0,1)分布採樣,並使用負對數對其進行縮放
2、通過Softmax函數將結果轉換為一個概率向量(一組凸和組件)
import torch
def gumbel_softmax_sample(logits, temperature):
y = logits + torch.randn_like(logits)
return F.softmax(y / temperature, dim=-1)
def gumbel_softmax(logits, temperature, hard=False):
"""
ST-gum: *ST*ochastic *GUM*ble-softmax.
"""
y = gumbel_softmax_sample(logits, temperature)
if hard:
y_hard = torch.zeros_like(y)
max_value, max_index = y.max(dim=-1, keepdim=True)
y_hard.scatter_(dim=-1, index=max_index, value=1.0)
y = (y_hard - y).detach() + y
return y
二、Gumbel Softmax的生成過程
假設我們有一組由骰子擲出的結果構成的序列,該序列中每個骰子擲出的數字之和為10。如果我們知道有多少種不同的序列可以得到這個和,我們就可以得到一個概率分布,該分布揭示了對於所有可能的序列而言,生成和為10的序列的概率是多少。在Gumbel Softmax中,我們使用負對數softmax將擲骰子的操作抽象為隨機變數採樣並將結果映射到概率分布空間上的一組向量。這裡擲骰過程的示例代碼:
num_trials, num_faces, target_value = 1000, 10, 10
dice_faces = torch.randint(1, num_faces + 1, size=(num_trials, target_value))
cumulative_sum = dice_faces.cumsum(dim=1)
indicator = (cumulative_sum == target_value)
target_count = indicator.sum(dim=0)
plt.figure(figsize=(8, 6))
plt.hist(target_count.numpy(), bins=np.arange(6, 40), density=True)
plt.xlabel('Number of successful events')
plt.ylabel('Probability')
plt.title('10d10 success count')
plt.show()
三、Gumbel Softmax的應用場景
Gumbel Softmax的應用場景主要涉及到使用生成模型處理離散數據,具體包括:
1、離散序列生成,即通過輸入生成符合要求的離散序列;
2、文本生成,即用於自然語言處理中,基於巨量的訓練數據進行建模,能夠生成新的語言句子;
3、推薦系統,即基於大數據模型,進行用戶行為分析和個性化推薦。
以上三個應用場景在神經網路建模中佔有重要的地位,由於該模型具備相對較強的分布擬合能力和計算效率,被廣泛應用於當代深度學習模型中。
四、Gumbel Softmax的優缺點
優點:
1、Gumbel Softmax演算法快速且高效,適用於大規模離散數據的建模和模擬;
2、Gumbel Softmax演算法顯著優於其他基於概率分布手段生成離散序列的演算法,具備更強的分布擬合能力和高階統計特性;
缺點:
1、Gumbel Softmax演算法對於小型數據集處理效果並不優秀,對於輸入空間受限的生成模型表現並不理想;
2、Gumbel Softmax演算法存在監督數據缺失問題,對於與數據樣本無法自動識別的離散空間作用不佳;
3、Gumbel Softmax演算法中存在過熱問題,具體來說,由於採樣過程中的雜訊,模型可能會生成具有極小概率的事件,這會對生成效果產生不利影響。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/308662.html