tensor list詳解

一、什麼是tensor list

1、Tensor


import torch
a = torch.tensor([1,2,3])

2、列表


lst = [1, 2, 3]

結合兩者,即是tensor list


lst = [torch.tensor([1,2,3]), torch.tensor([4,5,6])]

二、tensor list的操作

1、賦值操作


lst[0] = torch.tensor([7,8,9])

2、切片操作


lst_slice = lst[:1]

3、拼接操作


new_lst = lst + [torch.tensor([10,11,12])]

三、tensor list的應用

1、神經網路的前向傳播


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features=1000, out_features=100),
            nn.ReLU(),
            nn.Linear(in_features=100, out_features=10),
            nn.ReLU(),
            nn.Linear(in_features=10, out_features=1)
        )

    def forward(self, x):
        xs = []
        for layer in self.fc_layers:
            x = layer(x)
            xs.append(x)
        return xs

2、序列標註的解碼


def viterbi_decode(self, emissions: List[Tensor], transitions: Tensor,
                   decode_lengths: Optional[List[int]] = None) -> List[List[int]]:

    max_seq_length, batch_size, _ = emissions.shape
    mask = torch.ones(emissions.shape[:2], dtype=bool, device=emissions.device)
    path_scores = []
    path_indices = []
    last_idx = mask.sum(1) - 1
    # 發射概率
    emissions = emissions.permute(1, 0, 2)

    for i, (emission, batch_mask) in enumerate(zip(emissions, mask)):
        path_score, path_index = emission[0].unsqueeze(1), torch.zeros_like(emission[0]).unsqueeze(1).long()

        for j, (transition, score, last) in enumerate(zip(transitions, emission[1:], last_idx)):
            last = last.long()
            # 1、跨度
            broadcast_idx = batch_mask.unsqueeze(1).unsqueeze(2)
            broadcast_last = last.unsqueeze(1).unsqueeze(2)
            current_scores = path_score + transition + score.unsqueeze(1)
            current_scores[last == j] -= transitions[j]
            # 2、更新
            new_path_scores, new_path_indices = current_scores.max(dim=0)
            new_path_scores = torch.where(broadcast_idx, new_path_scores, path_score)
            new_path_indices = torch.where(broadcast_idx, new_path_indices, path_index)
            new_path_indices[last == j] = j

            path_score, path_index = new_path_scores, new_path_indices

        if decode_lengths is not None:
            path_index = [path_index[l, :dl] for l, dl in enumerate(decode_lengths[i].tolist())]
            path_score = [path_score[:dl, l] for l, dl in enumerate(decode_lengths[i].tolist())]

        path_scores.append(path_score)
        path_indices.append(path_index)

    # 計算總分值
    path_scores = [torch.stack(v).sum(0) for v in path_scores]
    return path_indices

四、tensor list的批量化處理

1、轉換張量


# tensor列錶轉換為一個大張量
tensors = [torch.randn(3, 4), torch.randn(5, 6)]
batched_t = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)

2、打包/解包張量


# 生成一個長度列表,列表中每一個元素代表一個batch的數據的長度(即句子長度)
packed_sequence = torch.nn.utils.rnn.pack_sequence(tensors, enforce_sorted=False)
unpacked_sequence = torch.nn.utils.rnn.pad_packed_sequence(packed_sequence, batch_first=True)

3、引入mask


# 先加入PAD,再加一個mask,把PAD排除在計算外
padded_sequence = nn.utils.rnn.pad_sequence(batched_tokens, batch_first=True, padding_value=vocab.token_to_idx[PAD])
mask = padded_sequence != vocab.token_to_idx[PAD]

五、總結與展望

Tensor List是PyTorch中經常使用到的數據結構,廣泛應用於深度學習領域中的多個任務中,如神經網路的前向傳播、序列標註的解碼等。結合PyTorch自身的優勢,我們可以高效地處理大量的數據,並實現了更加優秀的深度學習演算法。在未來,Tensor List繼續地用於深度學習演算法實現中,我們期待Tensor List的技術在處理海量數據、提升模型精度、加速模型訓練等多方面都可以有所突破。

原創文章,作者:RUSTH,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/360395.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
RUSTH的頭像RUSTH
上一篇 2025-02-24 00:33
下一篇 2025-02-24 00:33

相關推薦

  • Tensor to List的使用

    Tensor to List是TensorFlow框架提供的一個非常有用的函數,在很多的深度學習模型中都會用到。它的主要功能是將TensorFlow中的張量(Tensor)轉換為P…

    編程 2025-04-29
  • 如何使用Python將輸出值賦值給List

    對標題進行精確、簡明的解答:本文將從多個方面詳細介紹Python如何將輸出的值賦值給List。我們將分步驟進行探討,以便讀者更好地理解。 一、變數類型 在介紹如何將輸出的值賦值給L…

    編程 2025-04-28
  • Python List查找用法介紹

    在Python中,list是最常用的數據結構之一。在很多場景中,我們需要對list進行查找、篩選等操作。本文將從多個方面對Python List的查找方法進行詳細的闡述,包括基本查…

    編程 2025-04-28
  • Python DataFrame轉List用法介紹

    Python中常用的數據結構之一為DataFrame,但有時需要針對特定需求將DataFrame轉為List。本文從多個方面針對Python DataFrame轉List詳細介紹。…

    編程 2025-04-27
  • Python中list和tuple的用法及區別

    Python中list和tuple都是常用的數據結構,在開發中用途廣泛。本文將從使用方法、特點、存儲方式、可變性以及適用場景等多個方面對這兩種數據結構做詳細的闡述。 一、list和…

    編程 2025-04-27
  • 使用Flutter開發ToDo List App

    本文將會介紹如何使用Flutter開發一個實用的ToDo List App。ToDo List,即待辦事項清單,是一種記錄人們未處理工作和待辦事項的方式。隨著日常生活的快節奏,如此…

    編程 2025-04-27
  • Linux sync詳解

    一、sync概述 sync是Linux中一個非常重要的命令,它可以將文件系統緩存中的內容,強制寫入磁碟中。在執行sync之前,所有的文件系統更新將不會立即寫入磁碟,而是先緩存在內存…

    編程 2025-04-25
  • 神經網路代碼詳解

    神經網路作為一種人工智慧技術,被廣泛應用於語音識別、圖像識別、自然語言處理等領域。而神經網路的模型編寫,離不開代碼。本文將從多個方面詳細闡述神經網路模型編寫的代碼技術。 一、神經網…

    編程 2025-04-25
  • Python輸入輸出詳解

    一、文件讀寫 Python中文件的讀寫操作是必不可少的基本技能之一。讀寫文件分別使用open()函數中的’r’和’w’參數,讀取文件…

    編程 2025-04-25
  • nginx與apache應用開發詳解

    一、概述 nginx和apache都是常見的web伺服器。nginx是一個高性能的反向代理web伺服器,將負載均衡和緩存集成在了一起,可以動靜分離。apache是一個可擴展的web…

    編程 2025-04-25

發表回復

登錄後才能評論