一、squeeze函数概述
torch.squeeze(input, dim=None, out=None)
在PyTorch中,squeeze函数用于减少tensor张量中维度为1的维数,即将维数为1的维度压缩掉。
squeeze函数会返回新的tensor张量,原tensor张量不会发生改变,返回的tensor张量与输入张量在除去指定维度上是否为1的情况下完全相同。若维度不为1或未指定压缩维度,则会返回原来的张量。
二、squeeze函数的使用
1、压缩维度
import torch x = torch.randn(1, 3, 1, 5) print(x.size()) y = torch.squeeze(x) print(y.size())
输出结果为:
torch.Size([1, 3, 1, 5]) torch.Size([3, 5])
在上述示例中,输入张量的维度为(1,3,1,5),其中维度1与维度3都是1。当执行squeeze(x)操作时,函数会将这两个维度压缩掉,最终输出的结果为(3,5)。
2、指定压缩维度
import torch x = torch.randn(1, 3, 1, 5) print(x.size()) y = torch.squeeze(x, 0) print(y.size())
输出结果为:
torch.Size([1, 3, 1, 5]) torch.Size([3, 1, 5])
在上述示例中,squeeze函数会将张量x中维度为0的1压缩掉,输出的结果为(3,1,5)。
三、torch.squeeze的局限性
虽然squeeze函数能够压缩维度,但是由于函数实现的限制,不能够减少没有维度为1的维数。单个张量中,若维度1仅存在一个,那么该维度不能够被压缩掉。
import torch x = torch.randn(3) print(x.size()) y = torch.squeeze(x) print(y.size())
输出结果为:
torch.Size([3]) torch.Size([3])
因为在上述示例中,张量中仅存在一个维度为3的向量,无法对该维度进行压缩。
四、结语
以上就是squeeze函数的介绍和应用,虽然函数的操作看似简单,但在深度学习中,这些小的操作,如切片、拼接、压缩等都是必不可少的,熟练使用这些小的操作会极大地提升工程师的编程效率。
原创文章,作者:小蓝,如若转载,请注明出处:https://www.506064.com/n/182004.html