一、單一維度插入
unsqueeze函數是PyTorch中的一個核心函數,廣泛用於神經網絡模型的構建和編寫。它的作用主要是在指定維度上插入維度數為1的新維度。當我們需要在特定位置插入新的維度時,unsqueeze函數就非常有用了。
例如,我們有一個形狀為(3,4)的張量,希望在第二維度上插入新的維度,則代碼如下:
import torch x = torch.ones(3,4) print("原始張量形狀:", x.shape) #(3,4) x = torch.unsqueeze(x, 1) print("插入新維度後的張量形狀:", x.shape) #(3,1,4)
unsqueeze函數中的第一個參數是待操作的張量,第二個參數是要插入的維度的位置。在上面的示例中,我們在第二維度位置插入了新的維度。
二、多維度插入
unsqueeze函數不僅僅可以插入單個維度,還可以在一個張量中同時插入多個維度。可以通過向第二個參數傳入一個元組(tuple)來實現多維度插入操作。
例如,我們有一個形狀為(2,3)的張量,在第二維度和第四個維度上都插入新的維度,代碼如下:
import torch x = torch.ones(2,3) print("原始張量形狀:", x.shape) #(2,3) x = torch.unsqueeze(x, (1,3)) print("插入新維度後的張量形狀:", x.shape) #(2,1,3,1)
多次調用unsqueeze函數也可以實現多維度插入,但使用元組的方式更加簡便。
三、與unsqueeze相反的操作
如果我們在某個維度上插入的維度數為1,那麼此時我們也可以使用squeeze函數將這個維度刪除。
例如,我們有一個形狀為(2,1,3,1)的張量,將第二個維度刪除,則代碼如下:
import torch x = torch.ones(2,1,3,1) print("原始張量形狀:", x.shape) #(2,1,3,1) x = torch.squeeze(x, 1) print("刪除維度後的張量形狀:", x.shape) #(2,3,1)
squeeze函數和unsqueeze函數操作類似,第一個參數是待操作的張量,第二個參數是要刪除的維度的位置。在上面的示例中,我們刪除了維度為1的第二個維度。
四、與其他函數的組合應用
unsqueeze函數與其他函數的組合應用非常廣泛。例如,當我們需要對兩個張量進行相加操作時,需要滿足它們的維度數相同,這時我們可能需要插入一些新的維度,使得兩個張量維度數相同。實現方法就是通過unsqueeze函數將需要插入的維度插入進去。
例如,我們有兩個形狀分別為(2,3)和(1,3,1)的張量,希望將它們通過相加操作合併為一個張量,則代碼如下:
import torch x = torch.ones(2,3) y = torch.ones(1,3,1) x = torch.unsqueeze(x, 0) y = torch.squeeze(y, 2) z = x + y print("合併後的張量形狀:", z.shape) #(2,3,1)
在上面的示例中,我們通過unsqueeze函數給第一個張量插入了一個新的維度,在第一維度位置插入,讓它的形狀變為(1,2,3)。另外,我們還使用了squeeze函數將第二個張量的第二個維度刪除,讓它的形狀變為(1,3)。這樣兩個張量就可以進行相加操作了。
原創文章,作者:RVRXN,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/332496.html