一、squeeze函數概述
torch.squeeze(input, dim=None, out=None)
在PyTorch中,squeeze函數用於減少tensor張量中維度為1的維數,即將維數為1的維度壓縮掉。
squeeze函數會返回新的tensor張量,原tensor張量不會發生改變,返回的tensor張量與輸入張量在除去指定維度上是否為1的情況下完全相同。若維度不為1或未指定壓縮維度,則會返回原來的張量。
二、squeeze函數的使用
1、壓縮維度
import torch x = torch.randn(1, 3, 1, 5) print(x.size()) y = torch.squeeze(x) print(y.size())
輸出結果為:
torch.Size([1, 3, 1, 5]) torch.Size([3, 5])
在上述示例中,輸入張量的維度為(1,3,1,5),其中維度1與維度3都是1。當執行squeeze(x)操作時,函數會將這兩個維度壓縮掉,最終輸出的結果為(3,5)。
2、指定壓縮維度
import torch x = torch.randn(1, 3, 1, 5) print(x.size()) y = torch.squeeze(x, 0) print(y.size())
輸出結果為:
torch.Size([1, 3, 1, 5]) torch.Size([3, 1, 5])
在上述示例中,squeeze函數會將張量x中維度為0的1壓縮掉,輸出的結果為(3,1,5)。
三、torch.squeeze的局限性
雖然squeeze函數能夠壓縮維度,但是由於函數實現的限制,不能夠減少沒有維度為1的維數。單個張量中,若維度1僅存在一個,那麼該維度不能夠被壓縮掉。
import torch x = torch.randn(3) print(x.size()) y = torch.squeeze(x) print(y.size())
輸出結果為:
torch.Size([3]) torch.Size([3])
因為在上述示例中,張量中僅存在一個維度為3的向量,無法對該維度進行壓縮。
四、結語
以上就是squeeze函數的介紹和應用,雖然函數的操作看似簡單,但在深度學習中,這些小的操作,如切片、拼接、壓縮等都是必不可少的,熟練使用這些小的操作會極大地提升工程師的編程效率。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/182004.html