NN.Embedding详解

NN.Embedding是PyTorch中的一个常用模块,其主要作用是将输入的整数序列转换为密集向量表示。在自然语言处理(NLP)任务中,可以将每个单词表示成一个向量,从而方便进行下一步的计算和处理。

一、创建一个Embedding层

我们可以使用下面的代码创建一个Embedding层:

import torch.nn as nn

# 定义一个 Embedding 层,输入大小为 10,输出大小为 3
embedding = nn.Embedding(10, 3)

这里定义了一个输入大小为10,输出大小为3的Embedding层。其中输入大小10表示一共有10个单词或者是10个离散的特征,输出大小3表示每个单词/特征会被嵌入到一个3维的向量中。

二、查看Embedding层的参数

我们可以通过打印出Embedding层的参数来更好地理解它的作用:

print(embedding.weight)

输出结果如下:

tensor([[-0.4555,  2.0056,  0.3216],
        [-0.8817, -0.8111,  1.1015],
        [-1.0718,  0.6407, -0.2452],
        [-0.1458, -0.4591,  0.3504],
        [ 0.0302,  0.5518, -0.8721],
        [-0.1264, -1.5344,  0.6339],
        [-0.6904, -1.8824, -0.2472],
        [ 0.5966, -0.9738,  0.9559],
        [ 0.0134, -1.3174, -0.3511],
        [ 1.1453,  2.5714,  0.1814]], requires_grad=True)

从上面的输出结果中,我们可以看到一个大小为10×3的矩阵。其中的每一行代表了一个单词/特征的嵌入向量,每个元素都是一个浮点数。这个矩阵的值是在模型训练的过程中学习得到的。

三、输入数据并获取嵌入向量

我们可以使用下面的代码输入一个整数序列并获取嵌入向量:

# 输入一个大小为3的整数序列
input_sequence = torch.LongTensor([1, 5, 3])

# 获取嵌入向量
embedded_sequence = embedding(input_sequence)

print(embedded_sequence)

输出结果如下:

tensor([[-0.8817, -0.8111,  1.1015],
        [-0.1264, -1.5344,  0.6339],
        [-0.1458, -0.4591,  0.3504]], grad_fn=<EmbeddingBackward>)

从上面的输出结果中,我们可以看到一个大小为3×3的矩阵。其中的每一行代表了输入整数序列中对应的单词/特征嵌入向量,可以看到这个结果是和上面我们看到的参数是相一致的。

四、嵌入层在情感分析中的应用举例

举个例子,我们可以使用NN.Embedding来进行情感分析。下面的代码演示了如何将一段文本中的单词转换成嵌入向量,并使用卷积神经网络(CNN)进行情感分类:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SentimentClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_filters, filter_sizes, output_dim, pretrained_embeddings):
        super().__init__()

        # 定义 Embedding 层
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # 定义卷积层
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=1, out_channels=num_filters, kernel_size=(fs, embedding_dim)) 
            for fs in filter_sizes
        ])

        # 定义全连接层
        self.fc = nn.Linear(len(filter_sizes) * num_filters, output_dim)

        # 加载预训练的嵌入层权重
        self.embedding.weight.data.copy_(pretrained_embeddings)

    def forward(self, text):
        # text: [batch_size, sent_len]

        # 获取文本中每个单词对应的嵌入向量
        embedded = self.embedding(text) # embedded: [batch_size, sent_len, emb_dim]

        # 调整张量的维度使其适合卷积层的输入
        embedded = embedded.unsqueeze(1) # embedded: [batch_size, 1, sent_len, emb_dim]

        # 运行卷积和池化层
        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs] # conved: [batch_size, num_filters, sent_len - filter_sizes[n] + 1]

        # 对每个卷积层的输出进行最大池化
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved] # pooled: [batch_size, num_filters]

        # 把所有的池化层结果拼接到一起,作为全连接层的输入
        cat = self.fc(torch.cat(pooled, dim=1))

        return cat

上面的代码中,我们首先定义了一个SentimentClassifier类,该类继承自nn.Module,实现了一个简单的CNN分类器。其中,我们定义了一个Embedding层,它的参数包括词汇表的大小、嵌入维度以及一个预先训练好的嵌入向量。在前向传递过程中,我们使用了CNN对输入的单词进行特征提取,并经过一个全连接层输出情感分类的结果。

五、小结

NN.Embedding在自然语言处理任务中是一个非常常用的模块,它能够将离散的输入特征转换成密集的向量表示,并被广泛应用于文本分类、句向量生成、对话生成等任务中。

原创文章,作者:小蓝,如若转载,请注明出处:https://www.506064.com/n/310119.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
小蓝小蓝
上一篇 2025-01-04 19:32
下一篇 2025-01-04 19:32

相关推荐

  • 神经网络代码详解

    神经网络作为一种人工智能技术,被广泛应用于语音识别、图像识别、自然语言处理等领域。而神经网络的模型编写,离不开代码。本文将从多个方面详细阐述神经网络模型编写的代码技术。 一、神经网…

    编程 2025-04-25
  • Linux sync详解

    一、sync概述 sync是Linux中一个非常重要的命令,它可以将文件系统缓存中的内容,强制写入磁盘中。在执行sync之前,所有的文件系统更新将不会立即写入磁盘,而是先缓存在内存…

    编程 2025-04-25
  • nginx与apache应用开发详解

    一、概述 nginx和apache都是常见的web服务器。nginx是一个高性能的反向代理web服务器,将负载均衡和缓存集成在了一起,可以动静分离。apache是一个可扩展的web…

    编程 2025-04-25
  • Python安装OS库详解

    一、OS简介 OS库是Python标准库的一部分,它提供了跨平台的操作系统功能,使得Python可以进行文件操作、进程管理、环境变量读取等系统级操作。 OS库中包含了大量的文件和目…

    编程 2025-04-25
  • Java BigDecimal 精度详解

    一、基础概念 Java BigDecimal 是一个用于高精度计算的类。普通的 double 或 float 类型只能精确表示有限的数字,而对于需要高精度计算的场景,BigDeci…

    编程 2025-04-25
  • Linux修改文件名命令详解

    在Linux系统中,修改文件名是一个很常见的操作。Linux提供了多种方式来修改文件名,这篇文章将介绍Linux修改文件名的详细操作。 一、mv命令 mv命令是Linux下的常用命…

    编程 2025-04-25
  • MPU6050工作原理详解

    一、什么是MPU6050 MPU6050是一种六轴惯性传感器,能够同时测量加速度和角速度。它由三个传感器组成:一个三轴加速度计和一个三轴陀螺仪。这个组合提供了非常精细的姿态解算,其…

    编程 2025-04-25
  • Python输入输出详解

    一、文件读写 Python中文件的读写操作是必不可少的基本技能之一。读写文件分别使用open()函数中的’r’和’w’参数,读取文件…

    编程 2025-04-25
  • git config user.name的详解

    一、为什么要使用git config user.name? git是一个非常流行的分布式版本控制系统,很多程序员都会用到它。在使用git commit提交代码时,需要记录commi…

    编程 2025-04-25
  • 详解eclipse设置

    一、安装与基础设置 1、下载eclipse并进行安装。 2、打开eclipse,选择对应的工作空间路径。 File -> Switch Workspace -> [选择…

    编程 2025-04-25

发表回复

登录后才能评论