一、什麼是torch.repeat
torch.repeat 是 pyTorch 中的一個函數,它能將張量沿着指定的維度重複指定次數。重複張量的維度稱為repeat dims,這個函數的參數是一個torch.Size 的元組,包含了每個維度重複的次數。舉個例子,假如有一個形狀為(3,4)的張量,維度為0沿着重複2次,維度為1沿着重複3次,那麼該函數返回一個新的張量,形狀為 (6,12)
二、如何使用torch.repeat
torch.repeat 有個需要注意的地方是它複製張量來產生新的張量,所以需要使用完整的內存。這意味着你需要在使用該函數前將所需要重複的張量複製到GPU或 CPU上。接下來讓我們看一下如何使用這個函數。
# 導入torch
import torch
# 創建一個形狀為(2,2)的張量
x = torch.Tensor([[1,2],[3,4]])
# 沿着第0維和第1維分別重複2次和3次
y = x.repeat(2, 3)
# 打印結果
print(y)
本代碼中,我們首先導入了 pyTorch 庫,並創建了一個形狀為(2,2)的張量 x。接下來,我們使用 repeat 函數對 x 進行重複,其中第一個參數 2 表示第1維將被重複兩次,第二個參數 3 表示第2維將被重複三次。最後,我們打印出了結果 y。 輸出結果如下:
[[1. 2. 1. 2. 1. 2.]
[3. 4. 3. 4. 3. 4.]
[1. 2. 1. 2. 1. 2.]
[3. 4. 3. 4. 3. 4.]]
通過打印結果,我們可以看到張量 x 沿着第0維重複了兩次,沿着第1維重複了三次。重複後的張量 y 的形狀為 (4, 6), 並包含了重複後的值。
三、torch.repeat常見使用場景
torch.repeat 函數的常見應用場景分為以下兩種:
1、將張量複製多次並拼接成一個大張量
假設有一個形狀為(1,3)的張量 x,並將它重複3次並沿着第0維拼接成一個形狀為(3,3)的張量 y。
# 創建一個形狀為(1,3)的張量
x = torch.Tensor([[1,2,3]])
# 沿着第0維重複3次
y = x.repeat(3, 1)
# 打印結果
print(y)
輸出結果如下:
[[1. 2. 3.]
[1. 2. 3.]
[1. 2. 3.]]
2、將張量進行擴維並重複
使用 repeat 函數可以將原始張量擴展為新的張量。舉個例子,假如有一個形狀為(1,3)的張量 x,並將它重複3次並沿着第0維拼接成一個形狀為(3,3)的張量 y。
# 創建一個形狀為(1,3)的張量
x = torch.Tensor([[1,2,3]])
# 在第0維上添加一個新的維度
xx = x.unsqueeze(0)
# 沿着第0維和第1維進行重複
y = xx.repeat(3, 1, 1)
# 打印結果
print(y)
輸出結果如下:
[[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]
[[1. 2. 3.]]]
該例子中,我們首先創建了一個形狀為(1,3)的張量 x。接下來,使用 unsqueeze 函數在第0維上添加一個新的維度。最後,我們使用 repeat 函數沿着第0維和第1維進行重複並打印輸出結果。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/284718.html