tensor list详解

一、什么是tensor list

1、Tensor


import torch
a = torch.tensor([1,2,3])

2、列表


lst = [1, 2, 3]

结合两者,即是tensor list


lst = [torch.tensor([1,2,3]), torch.tensor([4,5,6])]

二、tensor list的操作

1、赋值操作


lst[0] = torch.tensor([7,8,9])

2、切片操作


lst_slice = lst[:1]

3、拼接操作


new_lst = lst + [torch.tensor([10,11,12])]

三、tensor list的应用

1、神经网络的前向传播


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc_layers = nn.Sequential(
            nn.Linear(in_features=1000, out_features=100),
            nn.ReLU(),
            nn.Linear(in_features=100, out_features=10),
            nn.ReLU(),
            nn.Linear(in_features=10, out_features=1)
        )

    def forward(self, x):
        xs = []
        for layer in self.fc_layers:
            x = layer(x)
            xs.append(x)
        return xs

2、序列标注的解码


def viterbi_decode(self, emissions: List[Tensor], transitions: Tensor,
                   decode_lengths: Optional[List[int]] = None) -> List[List[int]]:

    max_seq_length, batch_size, _ = emissions.shape
    mask = torch.ones(emissions.shape[:2], dtype=bool, device=emissions.device)
    path_scores = []
    path_indices = []
    last_idx = mask.sum(1) - 1
    # 发射概率
    emissions = emissions.permute(1, 0, 2)

    for i, (emission, batch_mask) in enumerate(zip(emissions, mask)):
        path_score, path_index = emission[0].unsqueeze(1), torch.zeros_like(emission[0]).unsqueeze(1).long()

        for j, (transition, score, last) in enumerate(zip(transitions, emission[1:], last_idx)):
            last = last.long()
            # 1、跨度
            broadcast_idx = batch_mask.unsqueeze(1).unsqueeze(2)
            broadcast_last = last.unsqueeze(1).unsqueeze(2)
            current_scores = path_score + transition + score.unsqueeze(1)
            current_scores[last == j] -= transitions[j]
            # 2、更新
            new_path_scores, new_path_indices = current_scores.max(dim=0)
            new_path_scores = torch.where(broadcast_idx, new_path_scores, path_score)
            new_path_indices = torch.where(broadcast_idx, new_path_indices, path_index)
            new_path_indices[last == j] = j

            path_score, path_index = new_path_scores, new_path_indices

        if decode_lengths is not None:
            path_index = [path_index[l, :dl] for l, dl in enumerate(decode_lengths[i].tolist())]
            path_score = [path_score[:dl, l] for l, dl in enumerate(decode_lengths[i].tolist())]

        path_scores.append(path_score)
        path_indices.append(path_index)

    # 计算总分值
    path_scores = [torch.stack(v).sum(0) for v in path_scores]
    return path_indices

四、tensor list的批量化处理

1、转换张量


# tensor列表转换为一个大张量
tensors = [torch.randn(3, 4), torch.randn(5, 6)]
batched_t = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)

2、打包/解包张量


# 生成一个长度列表,列表中每一个元素代表一个batch的数据的长度(即句子长度)
packed_sequence = torch.nn.utils.rnn.pack_sequence(tensors, enforce_sorted=False)
unpacked_sequence = torch.nn.utils.rnn.pad_packed_sequence(packed_sequence, batch_first=True)

3、引入mask


# 先加入PAD,再加一个mask,把PAD排除在计算外
padded_sequence = nn.utils.rnn.pad_sequence(batched_tokens, batch_first=True, padding_value=vocab.token_to_idx[PAD])
mask = padded_sequence != vocab.token_to_idx[PAD]

五、总结与展望

Tensor List是PyTorch中经常使用到的数据结构,广泛应用于深度学习领域中的多个任务中,如神经网络的前向传播、序列标注的解码等。结合PyTorch自身的优势,我们可以高效地处理大量的数据,并实现了更加优秀的深度学习算法。在未来,Tensor List继续地用于深度学习算法实现中,我们期待Tensor List的技术在处理海量数据、提升模型精度、加速模型训练等多方面都可以有所突破。

原创文章,作者:RUSTH,如若转载,请注明出处:https://www.506064.com/n/360395.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
RUSTHRUSTH
上一篇 2025-02-24 00:33
下一篇 2025-02-24 00:33

相关推荐

  • Tensor to List的使用

    Tensor to List是TensorFlow框架提供的一个非常有用的函数,在很多的深度学习模型中都会用到。它的主要功能是将TensorFlow中的张量(Tensor)转换为P…

    编程 2025-04-29
  • 如何使用Python将输出值赋值给List

    对标题进行精确、简明的解答:本文将从多个方面详细介绍Python如何将输出的值赋值给List。我们将分步骤进行探讨,以便读者更好地理解。 一、变量类型 在介绍如何将输出的值赋值给L…

    编程 2025-04-28
  • Python List查找用法介绍

    在Python中,list是最常用的数据结构之一。在很多场景中,我们需要对list进行查找、筛选等操作。本文将从多个方面对Python List的查找方法进行详细的阐述,包括基本查…

    编程 2025-04-28
  • Python DataFrame转List用法介绍

    Python中常用的数据结构之一为DataFrame,但有时需要针对特定需求将DataFrame转为List。本文从多个方面针对Python DataFrame转List详细介绍。…

    编程 2025-04-27
  • Python中list和tuple的用法及区别

    Python中list和tuple都是常用的数据结构,在开发中用途广泛。本文将从使用方法、特点、存储方式、可变性以及适用场景等多个方面对这两种数据结构做详细的阐述。 一、list和…

    编程 2025-04-27
  • 使用Flutter开发ToDo List App

    本文将会介绍如何使用Flutter开发一个实用的ToDo List App。ToDo List,即待办事项清单,是一种记录人们未处理工作和待办事项的方式。随着日常生活的快节奏,如此…

    编程 2025-04-27
  • Linux sync详解

    一、sync概述 sync是Linux中一个非常重要的命令,它可以将文件系统缓存中的内容,强制写入磁盘中。在执行sync之前,所有的文件系统更新将不会立即写入磁盘,而是先缓存在内存…

    编程 2025-04-25
  • 神经网络代码详解

    神经网络作为一种人工智能技术,被广泛应用于语音识别、图像识别、自然语言处理等领域。而神经网络的模型编写,离不开代码。本文将从多个方面详细阐述神经网络模型编写的代码技术。 一、神经网…

    编程 2025-04-25
  • Python输入输出详解

    一、文件读写 Python中文件的读写操作是必不可少的基本技能之一。读写文件分别使用open()函数中的’r’和’w’参数,读取文件…

    编程 2025-04-25
  • nginx与apache应用开发详解

    一、概述 nginx和apache都是常见的web服务器。nginx是一个高性能的反向代理web服务器,将负载均衡和缓存集成在了一起,可以动静分离。apache是一个可扩展的web…

    编程 2025-04-25

发表回复

登录后才能评论