一、什麼是tensor list
1、Tensor
from typing import List, Optional, Union

import torch
from torch import Tensor, nn
# A one-dimensional tensor built from Python ints (inferred dtype: torch.int64).
a = torch.tensor([1, 2, 3])
2、列表
lst = [1, 2, 3]  # an ordinary Python list of ints
將兩者結合,即由多個 Tensor 組成的 Python 列表,就是 tensor list:
# A "tensor list": a plain Python list whose elements are independent 1-D tensors.
lst = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4, 5, 6]),
]
二、tensor list的操作
1、賦值操作
# Rebind the first slot of the list to a new tensor (the list is modified in place).
lst[0] = torch.tensor([7, 8, 9])
2、切片操作
# Slicing returns a new list holding only the first element; the tensor
# objects themselves are shared with `lst`, not copied.
lst_slice = lst[0:1]
3、拼接操作
# Build a new list: every tensor from `lst` followed by one extra tensor.
# `lst` itself is left unchanged.
new_lst = [*lst, torch.tensor([10, 11, 12])]
三、tensor list的應用
1、神經網絡的前向傳播
class Net(nn.Module):
    """Small MLP whose forward pass returns every intermediate output.

    The network maps 1000 input features through hidden layers of width 100
    and 10 (each followed by ReLU) down to a single output value.
    """

    def __init__(self):
        super().__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features=1000, out_features=100),
            nn.ReLU(),
            nn.Linear(in_features=100, out_features=10),
            nn.ReLU(),
            nn.Linear(in_features=10, out_features=1),
        )

    def forward(self, x):
        """Run ``x`` through every layer and collect each layer's output.

        Args:
            x: input tensor whose last dimension is 1000 (required by the
                first ``nn.Linear``).

        Returns:
            A list with one tensor per module in ``fc_layers`` (5 entries):
            the outputs of Linear, ReLU, Linear, ReLU, Linear, in that order.
            The last entry is the network's final output.
        """
        outputs = []
        for layer in self.fc_layers:
            x = layer(x)
            # Keep every intermediate activation, not just the final one.
            outputs.append(x)
        return outputs
2、序列標註的解碼
def viterbi_decode(self, emissions: Union[Tensor, List[Tensor]], transitions: Tensor,
                   decode_lengths: Optional[List[int]] = None) -> List[List[int]]:
    """Return the highest-scoring tag sequence for every item in a batch.

    Standard Viterbi decoding for a linear-chain model, where the score of a
    tag path y_0 .. y_{L-1} for batch item b is
    ``sum_t emissions[t, b, y_t] + sum_{t>0} transitions[y_{t-1}, y_t]``.

    Args:
        emissions: emission scores shaped ``(max_seq_length, batch_size,
            num_tags)``. A list of per-timestep ``(batch_size, num_tags)``
            tensors is also accepted and stacked along a new first dimension.
        transitions: ``(num_tags, num_tags)`` tensor; ``transitions[i, j]`` is
            the score of moving from tag ``i`` to tag ``j``.
        decode_lengths: true length of each sequence in the batch (ints or a
            1-D integer tensor). Defaults to ``max_seq_length`` for every item.

    Returns:
        One list of tag indices per batch item; item ``b`` has
        ``decode_lengths[b]`` entries (an empty list when that length is 0).

    Raises:
        ValueError: on shape mismatches or out-of-range lengths.
    """
    # `self` is unused; it is kept so the method signature stays unchanged.
    if isinstance(emissions, (list, tuple)):
        emissions = torch.stack(list(emissions), dim=0)
    if emissions.dim() != 3:
        raise ValueError(
            "emissions must have shape (seq_len, batch, num_tags), "
            f"got {tuple(emissions.shape)}"
        )
    max_seq_length, batch_size, num_tags = emissions.shape
    if tuple(transitions.shape) != (num_tags, num_tags):
        raise ValueError(
            f"transitions must have shape ({num_tags}, {num_tags}), "
            f"got {tuple(transitions.shape)}"
        )

    if decode_lengths is None:
        lengths = [max_seq_length] * batch_size
    else:
        # int() accepts Python ints as well as 0-d tensors from a length tensor.
        lengths = [int(n) for n in decode_lengths]
        if len(lengths) != batch_size:
            raise ValueError(
                f"expected {batch_size} decode lengths, got {len(lengths)}"
            )
    for length in lengths:
        if not 0 <= length <= max_seq_length:
            raise ValueError(
                f"decode length {length} outside [0, {max_seq_length}]"
            )

    best_paths: List[List[int]] = []
    for b, length in enumerate(lengths):
        if length == 0:
            best_paths.append([])
            continue
        emission = emissions[:length, b]  # (length, num_tags)
        # score[j]: best score of any path over steps 0..t that ends in tag j.
        score = emission[0]
        backpointers = []
        for t in range(1, length):
            # candidates[i, j]: best path ending in tag i at step t-1, plus the
            # i -> j transition, plus the emission score of tag j at step t.
            candidates = score.unsqueeze(1) + transitions + emission[t].unsqueeze(0)
            score, best_prev = candidates.max(dim=0)
            backpointers.append(best_prev)
        # Follow the back-pointers from the best final tag to recover the path.
        path = [int(score.argmax())]
        for best_prev in reversed(backpointers):
            path.append(int(best_prev[path[-1]]))
        path.reverse()
        best_paths.append(path)
    return best_paths
四、tensor list的批量化處理
1、轉換張量
# Pad a list of variable-length tensors into one big batched tensor.
# pad_sequence pads along the first dimension only, so every tensor must share
# the same trailing dimensions (4 features here); the lengths (3 and 5) may differ.
tensors = [torch.randn(3, 4), torch.randn(5, 4)]
batched_t = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)  # shape (2, 5, 4)
2、打包/解包張量
# Pack the variable-length tensors into a PackedSequence (enforce_sorted=False
# lets them arrive in any length order), then unpack it back to padded form.
packed_sequence = torch.nn.utils.rnn.pack_sequence(
    tensors,
    enforce_sorted=False,
)
# pad_packed_sequence returns a (padded_tensor, lengths) tuple.
unpacked_sequence = torch.nn.utils.rnn.pad_packed_sequence(
    packed_sequence,
    batch_first=True,
)
3、引入mask
# Pad first, then build a mask so PAD positions can be excluded from computation.
# The mask is True for real tokens and False for padding.
# NOTE(review): batched_tokens, vocab and PAD are defined outside this snippet;
# batched_tokens is presumably a list of 1-D LongTensors of token ids.
pad_idx = vocab.token_to_idx[PAD]
padded_sequence = torch.nn.utils.rnn.pad_sequence(
    batched_tokens, batch_first=True, padding_value=pad_idx
)
mask = padded_sequence != pad_idx
五、總結與展望
Tensor List(即由 Tensor 組成的 Python 列表)是使用 PyTorch 時經常用到的數據結構,廣泛應用於深度學習的多種任務中,如神經網絡的前向傳播、序列標註的解碼等。結合 PyTorch 自身的優勢,我們可以高效地處理大量數據,並實現更加優秀的深度學習算法。未來,Tensor List 將繼續用於深度學習算法的實現中,我們期待 Tensor List 的相關技術在處理海量數據、提升模型精度、加速模型訓練等多方面都能有所突破。
原創文章,作者:RUSTH,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/360395.html