一、混合密度網路(MDN)概述
混合密度網路(Mixture Density Network, MDN)是一種基於神經網路的概率模型,可以預測多元輸出的概率分布。MDN的前身為混合高斯模型,其本質是將高斯模型擴展至多元輸出問題上。
一般來說,MDN可以用於建模連續變數或者分類問題中的多元數據輸出。另外,由於MDN還具備自適應能力,因此它可以適用於模型具有複雜雜訊結構的情況下,如自然語言處理、聲音處理、或者圖像識別。
二、混合密度網路(MDN)防止係數為1
在MDN中,防止係數為1是一種常見的技巧,它能夠讓輸出分布變得更加連續,並且防止出現”die-off”的情況。”Die-off”指的是某些輸出的分布出現尾部截斷的問題,當這種情況發生的時候,模型預測的輸出會變得非常敏感。防止係數為1的方法通常是通過將輸出分別乘以一個極小的定值(如1e-5)來實現。
def output_tensor(y_class, y_res):
# 將y_class的形狀從(batch_size, 1)轉換為(batch_size, K)
y_class_flat = tf.reshape(y_class, [-1])
ind = tf.range(tf.shape(y_class_flat)[0]) * K + tf.cast(y_class_flat, tf.int32)
mu = tf.gather(tf.reshape(y_res[:K * size_out], [-1, size_out]), ind)
sigma_hat = tf.exp(tf.gather(y_res[K * size_out:(2 * K + 1) * size_out], ind))
sigma = sigma_hat * tf.pow(1 + tf.pow(sigma_hat, 2) * self.reg, -0.5) # 適當的防止係數
alpha = tf.reshape(tf.nn.softmax(tf.reshape(y_res[(2 * K + 1) * size_out:], [-1, K])), [-1, K])
# 輸出mu, sigma和alpha
return mu, sigma, alpha
三、混合密度網路(MDN)評估
MDN評估一般採用負對數似然(Negative Loglikelihood)來進行。負對數似然是假定觀測值服從預測輸出分布後,在該分布下的似然函數的相反數。
def nll_loss(y_true, y_pred):
mu, sigma, alpha = output_tensor(y_pred[:, :1], y_pred[:, 1:])
gm = tfd.MixtureSameFamily(
mixture_distribution=tfd.Categorical(probs=alpha),
components_distribution=tfd.Normal(
loc=mu,
scale=sigma))
log_likelihood = gm.log_prob(y_true)
return -tf.reduce_mean(log_likelihood)
四、混合密度網路(MDN)做分類
對於分類任務,我們可以在MDN末端採用softmax函數來作為每個輸出類別的概率分布。在這種情況下,我們需要對損失函數進行改進,採用交叉熵損失函數來代替負對數似然損失函數。
def categorical_nll_loss(y_true, y_pred):
# 將y_class的形狀從(batch_size, 1)轉換為(batch_size, K)
y_class_flat = tf.reshape(y_true[:, :1], [-1])
ind = tf.range(tf.shape(y_class_flat)[0]) * K + tf.cast(y_class_flat, tf.int32)
alpha = tf.gather(tf.reshape(y_pred[:, :K * size_out], [-1, size_out]), ind)
log_likelihood = -tf.math.log(alpha)
loss = tf.reduce_mean(log_likelihood, axis=-1)
return loss
五、混合密度網路(MDN)多元回歸
多元回歸任務通常需要預測多個輸出變數,這時我們可以採用多個混合高斯分布來描述多個目標。在這種情況下,我們需要對多元高斯分布求解,具體方法可以採用「聯合分布法」或者「條件概率法」。
K = 3 # 採用3個混合高斯分布作為輸出
model = Sequential()
model.add(Dense(25, activation='relu'))
model.add(Dense(25, activation='relu'))
model.add(Dense(K * size_out + K + 1, activation='linear')) # 輸出為K * size_out個均值,K * size_out個標準差和K個係數
# 定義損失函數為負對數似然函數
model.compile(loss=nll_loss, optimizer=Adam(lr=0.001))
# 進行訓練
model.fit(data_train, label_train, epochs=100)
六、混合密度網路(MDN)多變數輸出
多變數輸出問題是指輸入變數為多維度,輸出變數也為多維度的條件概率分布問題。在這種情況下,我們可以採用獨立的多元高斯分布來分別描述每個輸出維度的條件概率,或者採用圖形模型來描述多維度之間的條件概率關係
def create_model(input_shape, output_shape):
input_layer = Input(shape=input_shape)
hidden = Dense(units=128, activation='relu')(input_layer)
hidden = Dense(units=64, activation='relu')(hidden)
# 為每個輸出維度定義輸出分布
output_layers = []
activations = ['linear', 'exponential', 'sigmoid', 'tanh']
for i in range(output_shape[0]):
out = Dense(units=3 * output_shape[1], activation=activations[i])(hidden)
output_layers.append(out)
output_layer = Concatenate(axis=-1)(output_layers)
model = Model(inputs=[input_layer], outputs=[output_layer])
# 定義損失函數為負對數似然函數
model.compile(optimizer='adam', loss=nll_loss)
return model
七、混合密度網路(MDN)相關論文
A Density Network Approach to Improving the Generalization of Deep Neural Networks
A Mixture Density Network for Bankruptcy Prediction Using Alternative Data
原創文章,作者:LMWIK,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/331151.html