一、基本概念
Tensor.view() 是 PyTorch 中 Tensor 的一種方法,用於改變 Tensor 的形狀或維度,不改變其數據存儲。view() 方法有兩個作用:
- 改變 Tensor 的形狀或維度:原 Tensor 的元素總數只能是被重新定義的 Tensor 的元素總數的整數倍。
- 不改變 Tensor 的數據存儲:將原 Tensor 中的數據存儲於新 Tensor 中的方式與原 Tensor 相同。
這意味着,view() 操作可以有效地避免重新分配內存空間。
二、維度變換示例
下面,我們來看一個示例,將一個 3×3 的矩陣轉換為行向量、列向量、以及一個大小為 9 的一維向量。
import torch
import numpy as np
# 創建一個 3x3 的矩陣
A = torch.tensor(np.random.randint(10, size=(3, 3)))
print("A =", A)
# 轉換為行向量
A_row_vector = A.view(1, -1)
print("A_row_vector =", A_row_vector)
# 轉換為列向量
A_column_vector = A.view(-1, 1)
print("A_column_vector =", A_column_vector)
# 轉換為大小為 9 的一維向量
A_one_dimensional = A.view(-1)
print("A_one_dimensional =", A_one_dimensional)
輸出結果:
A = tensor([[7, 1, 7],
[0, 4, 3],
[4, 8, 7]])
A_row_vector = tensor([[7, 1, 7, 0, 4, 3, 4, 8, 7]])
A_column_vector = tensor([[7],
[1],
[7],
[0],
[4],
[3],
[4],
[8],
[7]])
A_one_dimensional = tensor([7, 1, 7, 0, 4, 3, 4, 8, 7])
可以看到,我們通過 view 操作,快速變換了 Tensor 的維度,不需要重新為 Tensor 分配內存空間。
三、Tensor.view 的使用場景
1、Tensor.view 的作用
Tensor.view() 的意義在於改變 Tensor 的形狀或維度,但是不會改變數據存儲。一般而言,該方法用於以下場景:
- 改變 Tensor 的維度,包括變換形狀、擴展維度等。
- 進行 Tensor 的不同視角查看,比如將圖像從一維向量轉換為二維像素矩陣。
2、Tensor.view 的注意事項
Tensor.view() 有以下幾點需要注意:
- 原 Tensor 的元素總數只能是被重新定義的 Tensor 的元素總數的整數倍。
- 對新 Tensor 進行操作可能會危及原 Tensor 的數據。
- 對新 Tensor 的改動不會影響原 Tensor。
3、Tensor.view 與 Torch.squeeze 的區別
Tensor.view() 和 Torch.squeeze() 都可以用來壓縮 tensor 的維度,兩者的不同點在於:
- Tensor.view() 壓縮維度後需要形狀匹配,不匹配會報錯;而 Torch.squeeze() 不需要形狀匹配,可以自動匹配。
- Tensor.view() 壓縮的維度必須為 1;而 Torch.squeeze() 可以壓縮任意的維度。
下面是一個 Tensor.view() 與 Torch.squeeze() 的示例,將一個 Tensor 中維度為 1 的維度壓縮。
import torch
# 創建一個大小為 1x3x1x2 的 Tensor
A = torch.tensor([[[[1, 2]], [[3, 4]], [[5, 6]]]])
print("Size of A:", A.size())
# 使用 Tensor.view() 壓縮維度
B = A.view(3, 2)
print("Size of B:", B.size())
# 使用 Torch.squeeze() 壓縮維度
C = torch.squeeze(A)
print("Size of C:", C.size())
輸出結果:
Size of A: torch.Size([1, 3, 1, 2])
Size of B: torch.Size([3, 2])
Size of C: torch.Size([3, 2])
可以看到,使用 Tensor.view() 壓縮維度需要事先知道 Tensor 的形狀,而 Torch.squeeze() 可以自動判斷並壓縮指定的維度。
四、小結
Tensor.view() 是 PyTorch 中常用的一種 Tensor 方法,可用於改變 Tensor 的形狀或維度,不改變其數據存儲。它可以快速地進行 Tensor 維度變換,並且可以避免重新分配內存空間。Tensor.view() 的使用場景主要在於改變 Tensor 的維度,包括變換形狀、擴展維度、不同視角查看等場景。在使用 Tensor.view() 注意要求的元素總數必須為被重新定義的 Tensor 的元素總數的整數倍。同時,還可以將 Tensor.view() 與 Torch.squeeze() 相結合使用,來壓縮 tensor 的維度。
原創文章,作者:CZVLY,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/360908.html