一、概述
在 Pytorch 中,我們經常需要處理不同維度的張量數據。unsqueeze() 方法就是用來增加張量的維度的,它會在指定位置增加一維。而其中的 unsqueeze(0) 就是在索引位置 0 上增加一維。
下面我們將從多個方面詳細闡述 unsqueeze(0) 方法。
二、增加維度
unsqueeze(0) 的主要作用就是在張量最前面增加一維。
舉個例子,我們有一個 1 維張量 tensor1 = torch.tensor([1, 2, 3]),如果我們想將其轉換成 2 維張量,可以使用 unsqueeze(0) 方法,在索引位置 0 上增加一維。
import torch
tensor1 = torch.tensor([1, 2, 3])
tensor1_2d = tensor1.unsqueeze(0)
print(tensor1_2d.shape) # 輸出 torch.Size([1, 3])
可以看到,原先的 1 維張量變成了 2 維張量,第一個維度的大小變成了 1。
同理,我們還可以進行多次 unsqueeze(0) 操作,增加多個維度:
import torch
tensor1 = torch.tensor([1, 2, 3])
tensor2 = tensor1.unsqueeze(0).unsqueeze(0)
print(tensor2.shape) # 輸出 torch.Size([1, 1, 3])
可以看到,這次我們進行了兩次 unsqueeze(0),在原先的基礎上增加了兩個維度。
三、在模型中的應用
unsqueeze(0) 方法在深度學習模型中也是常用的操作之一。比如,在卷積神經網絡中,輸入通常是 4 維張量,分別表示 batch_size, channel, height, width。
如果我們的數據集只有一張圖片,那麼 batch_size 就為 1。為了將數據集格式化成網絡所需要的輸入格式,我們就需要將單張圖片的 3 維張量轉換成 4 維張量。這時候 unsqueeze(0) 就能派上用場了。
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv = nn.Conv2d(1, 10, kernel_size=3)
def forward(self, input):
x = input.unsqueeze(0) # 將 3 維張量轉換成 4 維張量
out = self.conv(x)
return out
net = Net()
input = torch.randn(1, 28, 28)
output = net(input)
print(output.shape) # 輸出 torch.Size([1, 10, 26, 26])
可以看到,通過 unsqueeze(0),我們將輸入張量從 3 維轉換成了 4 維,成功地將數據集格式化成了網絡所需要的輸入格式。
四、拼接操作
unsqueeze(0) 方法還能和其他張量拼接操作一起使用。
比如,我們有兩個 2 維張量 tensor1 和 tensor2,如果想在第一個維度上進行拼接,就需要對它們進行 unsqueeze(0) 操作,然後再使用 cat() 方法進行拼接。
import torch
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])
# 在第一個維度上進行拼接
tensor3 = torch.cat((tensor1.unsqueeze(0), tensor2.unsqueeze(0)), dim=0)
print(tensor3.shape) # 輸出 torch.Size([2, 2, 3])
可以看到,通過 unsqueeze(0) 和 cat() 方法,我們成功地在第一個維度上將兩個 2 維張量拼接成了一個 3 維張量。
五、實現 broadcast_to
unsqueeze(0) 還能用來實現 broadcast_to 操作。broadcast_to 操作是指將一個張量的形狀擴展成指定的形狀。
import torch
def broadcast_to(input, shape):
# 先求出原始形狀和目標形狀的差距
diff = len(shape) - len(input.shape)
# 在 input 最前面增加與目標形狀相差的維數個維度
for _ in range(diff):
input = input.unsqueeze(0)
# 使用 expand 方法擴展形狀
return input.expand(shape)
x = torch.tensor([1, 2, 3])
y = broadcast_to(x, [2, 3])
print(y)
可以看到,使用 unsqueeze(0) 和 expand() 方法,我們成功地將 1 維張量 x 擴展成了形狀為 [2, 3] 的張量 y。
六、總結
unsqueeze(0) 方法是 Pytorch 中常用的增加張量維度的方法之一。它能在指定位置上增加一維,可以與其他拼接操作一起使用,也可以用來實現 broadcast_to 操作。在深度學習模型中,使用 unsqueeze(0) 能夠方便地將數據集格式化成網絡所需要的輸入格式。
使用 unsqueeze(0) 方法需要注意,增加的維度大小是 1,如果需要增加其他大小的維度,需要使用 unsqueeze() 方法,並制定對應的索引位置。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/279470.html