PyTorch Permute詳解

一、Permute是什麼?

Permute是PyTorch中的一個函數,用於對Tensor進行維度重排,實現Tensor在維度上的特定排序,可以輸入參數來指定新的維度順序。

二、Permute的使用方法

Permute使用方法如下:

torch.permute(*dims)

參數說明:

  • dims:Tensor新的維度順序,用變數長度的可迭代對象表示

使用示例:

import torch

# 定義一個5維張量
a = torch.randn(2, 3, 4, 5, 6)

# 將a的維度順序從[0, 1, 2, 3, 4]調整為[4, 3, 2, 0, 1]
b = a.permute(4, 3, 2, 0, 1)

print('a.shape:', a.shape)
print('b.shape:', b.shape)

輸出結果:

a.shape: torch.Size([2, 3, 4, 5, 6])
b.shape: torch.Size([6, 5, 4, 2, 3])

從輸出結果可以看出,原始張量a的維度順序為[0, 1, 2, 3, 4],新張量b的維度順序為[4, 3, 2, 0, 1]。

三、Permute的示例應用

1、數據增廣

在深度學習中,數據集的增廣是提高模型性能和泛化能力的有效方法。Permute函數可用於數據增廣中的圖像鏡像操作。我們可以通過改變Tensor的維度順序來實現圖像水平和垂直翻轉。

代碼示例:

import torch
import torchvision.transforms.functional as TF

# 定義一張大小為[256, 256]的隨機圖像
img = torch.randn(3, 256, 256)

# 水平翻轉
img_hflip = TF.hflip(img)
img_hflip_permute = img_hflip.permute(0, 2, 1)

# 垂直翻轉
img_vflip = TF.vflip(img)
img_vflip_permute = img_vflip.permute(0, 2, 1)

print('img_hflip_permute.shape:', img_hflip_permute.shape)
print('img_vflip_permute.shape:', img_vflip_permute.shape)

輸出結果:

img_hflip_permute.shape: torch.Size([3, 256, 256])
img_vflip_permute.shape: torch.Size([3, 256, 256])

從結果中可以看出,經過Permute重新排序的圖像張量形狀與原始圖像形狀一致,可以方便地與其他圖像增廣方式進行堆疊。

2、卷積輸出通道交換

在卷積神經網路中,卷積層輸出的特徵圖維度順序通常為[Batch, Channel, Height, Width],每個卷積核對應一個輸出通道。當模型需要跨多個GPU進行並行計算時,由於每個GPU上的卷積層只計算部分特徵圖,輸出通道會被分配到不同的GPU上。為了保證模型的正確性,輸出通道需要重新排列。

代碼示例:

import torch

# 定義一個卷積層輸出的特徵圖張量,假設有16個卷積核
feats_map = torch.randn(32, 16, 28, 28)

# 將特徵圖的通道數從16變為4,並交換通道順序
new_feats_map = feats_map[:, [2, 5, 10, 13,], :, :]
new_feats_map = new_feats_map.permute(0, 2, 3, 1)

print('new_feats_map.shape:', new_feats_map.shape)

輸出結果:

new_feats_map.shape: torch.Size([32, 28, 28, 4])

從結果中可以看出,經過Permute重新排序的特徵圖形狀與原始特徵圖形狀一致,只是通道數被縮減。

3、圖像語義分割

在圖像語義分割中,每個像素都需要指定分類標籤。在訓練過程中,輸入圖像和標籤圖必須同時進行水平和垂直翻轉。為了方便進行網路訓練,我們使用Permute函數對標籤圖的維度順序進行調整。

代碼示例:

import torch

# 定義一張大小為[256, 256]的分類標籤圖
label = torch.randint(low=0, high=20, size=(256, 256))

# 模擬進行水平翻轉和垂直翻轉
label_hflip = label.flip(dims=(1,))
label_vflip = label.flip(dims=(0,))

# 把翻轉後的標籤圖通過Permute函數進行重新排列
label_hvflip_permute = label_hflip.permute(1, 0)
label_vflip_permute = label_vflip.permute(1, 0)

print('label_hvflip_permute.shape:', label_hvflip_permute.shape)
print('label_vflip_permute.shape:', label_vflip_permute.shape)

輸出結果:

label_hvflip_permute.shape: torch.Size([256, 256])
label_vflip_permute.shape: torch.Size([256, 256])

從結果中可以看出,經過Permute重新排序後的標籤圖形狀與原始標籤圖形狀一致,在進行網路訓練時能夠方便地和輸入圖像配對。

原創文章,作者:CUAZ,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/133273.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
CUAZ的頭像CUAZ
上一篇 2024-10-03 23:57
下一篇 2024-10-03 23:57

相關推薦

  • PyTorch模塊簡介

    PyTorch是一個開源的機器學習框架,它基於Torch,是一個Python優先的深度學習框架,同時也支持C++,非常容易上手。PyTorch中的核心模塊是torch,提供一些很好…

    編程 2025-04-27
  • Linux sync詳解

    一、sync概述 sync是Linux中一個非常重要的命令,它可以將文件系統緩存中的內容,強制寫入磁碟中。在執行sync之前,所有的文件系統更新將不會立即寫入磁碟,而是先緩存在內存…

    編程 2025-04-25
  • 神經網路代碼詳解

    神經網路作為一種人工智慧技術,被廣泛應用於語音識別、圖像識別、自然語言處理等領域。而神經網路的模型編寫,離不開代碼。本文將從多個方面詳細闡述神經網路模型編寫的代碼技術。 一、神經網…

    編程 2025-04-25
  • Python輸入輸出詳解

    一、文件讀寫 Python中文件的讀寫操作是必不可少的基本技能之一。讀寫文件分別使用open()函數中的’r’和’w’參數,讀取文件…

    編程 2025-04-25
  • C語言貪吃蛇詳解

    一、數據結構和演算法 C語言貪吃蛇主要運用了以下數據結構和演算法: 1. 鏈表 typedef struct body { int x; int y; struct body *nex…

    編程 2025-04-25
  • Java BigDecimal 精度詳解

    一、基礎概念 Java BigDecimal 是一個用於高精度計算的類。普通的 double 或 float 類型只能精確表示有限的數字,而對於需要高精度計算的場景,BigDeci…

    編程 2025-04-25
  • git config user.name的詳解

    一、為什麼要使用git config user.name? git是一個非常流行的分散式版本控制系統,很多程序員都會用到它。在使用git commit提交代碼時,需要記錄commi…

    編程 2025-04-25
  • Linux修改文件名命令詳解

    在Linux系統中,修改文件名是一個很常見的操作。Linux提供了多種方式來修改文件名,這篇文章將介紹Linux修改文件名的詳細操作。 一、mv命令 mv命令是Linux下的常用命…

    編程 2025-04-25
  • MPU6050工作原理詳解

    一、什麼是MPU6050 MPU6050是一種六軸慣性感測器,能夠同時測量加速度和角速度。它由三個感測器組成:一個三軸加速度計和一個三軸陀螺儀。這個組合提供了非常精細的姿態解算,其…

    編程 2025-04-25
  • Python安裝OS庫詳解

    一、OS簡介 OS庫是Python標準庫的一部分,它提供了跨平台的操作系統功能,使得Python可以進行文件操作、進程管理、環境變數讀取等系統級操作。 OS庫中包含了大量的文件和目…

    編程 2025-04-25

發表回復

登錄後才能評論