一、Permute是什麼?
Permute是PyTorch中的一個函數,用於對Tensor進行維度重排,實現Tensor在維度上的特定排序,可以輸入參數來指定新的維度順序。
二、Permute的使用方法
Permute使用方法如下:
torch.permute(*dims)
參數說明:
dims
:Tensor新的維度順序,用變量長度的可迭代對象表示
使用示例:
import torch
# 定義一個5維張量
a = torch.randn(2, 3, 4, 5, 6)
# 將a的維度順序從[0, 1, 2, 3, 4]調整為[4, 3, 2, 0, 1]
b = a.permute(4, 3, 2, 0, 1)
print('a.shape:', a.shape)
print('b.shape:', b.shape)
輸出結果:
a.shape: torch.Size([2, 3, 4, 5, 6])
b.shape: torch.Size([6, 5, 4, 2, 3])
從輸出結果可以看出,原始張量a的維度順序為[0, 1, 2, 3, 4],新張量b的維度順序為[4, 3, 2, 0, 1]。
三、Permute的示例應用
1、數據增廣
在深度學習中,數據集的增廣是提高模型性能和泛化能力的有效方法。Permute函數可用於數據增廣中的圖像鏡像操作。我們可以通過改變Tensor的維度順序來實現圖像水平和垂直翻轉。
代碼示例:
import torch
import torchvision.transforms.functional as TF
# 定義一張大小為[256, 256]的隨機圖像
img = torch.randn(3, 256, 256)
# 水平翻轉
img_hflip = TF.hflip(img)
img_hflip_permute = img_hflip.permute(0, 2, 1)
# 垂直翻轉
img_vflip = TF.vflip(img)
img_vflip_permute = img_vflip.permute(0, 2, 1)
print('img_hflip_permute.shape:', img_hflip_permute.shape)
print('img_vflip_permute.shape:', img_vflip_permute.shape)
輸出結果:
img_hflip_permute.shape: torch.Size([3, 256, 256])
img_vflip_permute.shape: torch.Size([3, 256, 256])
從結果中可以看出,經過Permute重新排序的圖像張量形狀與原始圖像形狀一致,可以方便地與其他圖像增廣方式進行堆疊。
2、卷積輸出通道交換
在卷積神經網絡中,卷積層輸出的特徵圖維度順序通常為[Batch, Channel, Height, Width],每個卷積核對應一個輸出通道。當模型需要跨多個GPU進行並行計算時,由於每個GPU上的卷積層只計算部分特徵圖,輸出通道會被分配到不同的GPU上。為了保證模型的正確性,輸出通道需要重新排列。
代碼示例:
import torch
# 定義一個卷積層輸出的特徵圖張量,假設有16個卷積核
feats_map = torch.randn(32, 16, 28, 28)
# 將特徵圖的通道數從16變為4,並交換通道順序
new_feats_map = feats_map[:, [2, 5, 10, 13,], :, :]
new_feats_map = new_feats_map.permute(0, 2, 3, 1)
print('new_feats_map.shape:', new_feats_map.shape)
輸出結果:
new_feats_map.shape: torch.Size([32, 28, 28, 4])
從結果中可以看出,經過Permute重新排序的特徵圖形狀與原始特徵圖形狀一致,只是通道數被縮減。
3、圖像語義分割
在圖像語義分割中,每個像素都需要指定分類標籤。在訓練過程中,輸入圖像和標籤圖必須同時進行水平和垂直翻轉。為了方便進行網絡訓練,我們使用Permute函數對標籤圖的維度順序進行調整。
代碼示例:
import torch
# 定義一張大小為[256, 256]的分類標籤圖
label = torch.randint(low=0, high=20, size=(256, 256))
# 模擬進行水平翻轉和垂直翻轉
label_hflip = label.flip(dims=(1,))
label_vflip = label.flip(dims=(0,))
# 把翻轉後的標籤圖通過Permute函數進行重新排列
label_hvflip_permute = label_hflip.permute(1, 0)
label_vflip_permute = label_vflip.permute(1, 0)
print('label_hvflip_permute.shape:', label_hvflip_permute.shape)
print('label_vflip_permute.shape:', label_vflip_permute.shape)
輸出結果:
label_hvflip_permute.shape: torch.Size([256, 256])
label_vflip_permute.shape: torch.Size([256, 256])
從結果中可以看出,經過Permute重新排序後的標籤圖形狀與原始標籤圖形狀一致,在進行網絡訓練時能夠方便地和輸入圖像配對。
原創文章,作者:CUAZ,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/133273.html