深度学习之hierarchicalsoftmax

一、什么是hierarchicalsoftmax

hierarchicalsoftmax是一种用于优化神经网络中softmax函数计算速度的方法。在传统的softmax函数中,需要对每个候选类别计算概率,这导致计算量呈指数级增长。Huffman树是一种二叉树结构,旨在通过分配更短的编码来最小化字符编码的平均长度。基于Huffman树,hierarchicalsoftmax可以将softmax计算复杂度减少为O(log(n)),其中n是类别总数。

在hierarchicalsoftmax中,所有可能的输出类别都被视为二叉树的节点。每个节点都有一段唯一的编码。在推断时,softmax操作沿着树从根节点开始移动,直到找到输出节点并计算对应的概率。

通俗地理解,hierarchicalsoftmax可以看作是将原本softmax中的每个类别映射为一个节点,然后用二叉树的形式展示。每个节点都可以得到一个唯一的binary code。在实际中,用hierarchicalsoftmax代替传统softmax可以大幅度地减少参数大小和模型复杂度,从而加速模型训练和推理。

二、hierarchicalsoftmax的优点

1.减少模型参数:hierarchicalsoftmax通过二叉树结构来组织类别标签,有效降低了softmax的计算复杂度。相应的,也能减少模型的参数数量和计算时间。

2.更快的训练和推理速度:传统softmax方法需要计算每个输出类别的概率值,而hierarchicalsoftmax只需要向下遍历Huffman树即可。因此,hierarchicalsoftmax可以显著减少计算量,提高训练和推理效率。

3.适合处理大规模分类问题:由于传统的softmax方法需要计算所有可能的类别的概率值,因此对于大规模分类问题计算量过大,而hierarchicalsoftmax可以在常规硬件设备上处理上百万个类别的分类问题。

三、如何使用hierarchicalsoftmax

在tensorflow中,可以通过设置softmax_weights和softmax_biases的参数实现hierarchicalsoftmax。先用一个batch对模型进行一次forward,通过实例化HuffmanTree类,将训练数据传入。创建完成Huffman树后,即可计算对应节点的编码和概率值。


import tensorflow as tf
from tensorflow.contrib.framework import nest
from tensorflow.contrib.rnn import LSTMStateTuple
from tensorflow.python.ops.rnn import dynamic_rnn

logit = tf.contrib.layers.fully_connected(
    inputs=last_outputs,
    num_outputs=output_dimension,
    activation_fn=None,
    weights_initializer=tf.truncated_normal_initializer(stddev=1e-4),
    biases_initializer=tf.zeros_initializer(),
    scope='hierarchical_softmax_logit'
)

# create a softmax weight matrix for each branch
hierarchical_softmax_weights = [tf.Variable(
    tf.truncated_normal([branch_size, output_dimension], stddev=1e-4),
    name="hierarchical_softmax_weights_%d" % i)
for i, branch_size in enumerate(huffman_tree.branch_sizes)]

# split the variables into a list for each branch
hierarchical_softmax_weights_branches = nest.pack_sequence_as(
    structure=huffman_tree.branch_sizes,
    flat_sequence=hierarchical_softmax_weights)

# compute the logits for each branch
logits = nest.map_structure(
    lambda w: tf.matmul(last_outputs, w, transpose_b=True),
    hierarchical_softmax_weights_branches)

# induce a softmax on them
softmaxes = nest.map_structure(
    lambda l: tf.nn.softmax(l, dim=1),
    logits)

# assign unique paths from the root node to all of the leafs
hierarchical_paths = huffman_tree.paths()

# get the full word embeddings for each unique word in the tree
full_embeddings = tf.gather(
    params=full_embeddings,
    indices=huffman_tree.word_ids())

weights_t = tf.transpose(hierarchical_softmax_weights_branches, [1, 0, 2])
weights_flat = tf.reshape(weights_t, [-1, output_dimension])

biases_flat = tf.Variable(
    tf.zeros([tf.reduce_sum(huffman_tree.branch_sizes)]),
    name="hierarchical_softmax_biases")

hierarchical_softmax_biases_branches = tf.split(
    biases_flat, huffman_tree.branch_sizes)

biases = nest.pack_sequence_as(
    structure=hierarchical_softmax_weights_branches,
    flat_sequence=hierarchical_softmax_biases_branches)

l_prods = nest.map_structure(
    lambda s, l: tf.matmul(l, s, transpose_b=True), hierarchical_paths, softmaxes)

prods = tf.reduce_prod(l_prods, axis=0)

dot = tf.matmul(full_embeddings, weights_flat, transpose_b=True)

z = tf.add(dot, biases_flat)

pred = tf.multiply(z, prods)

prediction = tf.nn.softmax(pred, 1)

loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=pred)

四、hierarchicalsoftmax的局限性和应用

1.局限性:由于hierarchicalsoftmax是依靠Huffman树构建的,因此其对类别分布的偏置和采样方式较为敏感。在类别分布不均衡的情况下,Huffman树的构建往往会是非常非常慢,甚至不可用。

2.应用:hierarchicalsoftmax在大规模分类问题中表现出了优异的性能。例如,可以通过构建超大型的分类词典以实现高级的文本语言建模。hierarchicalsoftmax也可以用于其他类型的分类问题,例如多标签分类。

五、小结

hierarchicalsoftmax是一种用于提高softmax计算速度的算法。相比传统softmax,改进方案通过构建Huffman树,将分类问题以一种更加简洁的方式来展示。在大规模分类问题中,hierarchicalsoftmax是一种值得尝试的算法。

原创文章,作者:小蓝,如若转载,请注明出处:https://www.506064.com/n/237035.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
小蓝小蓝
上一篇 2024-12-12 12:03
下一篇 2024-12-12 12:03

相关推荐

  • 深度查询宴会的文化起源

    深度查询宴会,是指通过对一种文化或主题的深度挖掘和探究,为参与者提供一次全方位的、深度体验式的文化品尝和交流活动。本文将从多个方面探讨深度查询宴会的文化起源。 一、宴会文化的起源 …

    编程 2025-04-29
  • Python下载深度解析

    Python作为一种强大的编程语言,在各种应用场景中都得到了广泛的应用。Python的安装和下载是使用Python的第一步,对这个过程的深入了解和掌握能够为使用Python提供更加…

    编程 2025-04-28
  • Python递归深度用法介绍

    Python中的递归函数是一个函数调用自身的过程。在进行递归调用时,程序需要为每个函数调用开辟一定的内存空间,这就是递归深度的概念。本文将从多个方面对Python递归深度进行详细阐…

    编程 2025-04-27
  • Spring Boot本地类和Jar包类加载顺序深度剖析

    本文将从多个方面对Spring Boot本地类和Jar包类加载顺序做详细的阐述,并给出相应的代码示例。 一、类加载机制概述 在介绍Spring Boot本地类和Jar包类加载顺序之…

    编程 2025-04-27
  • 深度解析Unity InjectFix

    Unity InjectFix是一个非常强大的工具,可以用于在Unity中修复各种类型的程序中的问题。 一、安装和使用Unity InjectFix 您可以通过Unity Asse…

    编程 2025-04-27
  • 深度剖析:cmd pip不是内部或外部命令

    一、问题背景 使用Python开发时,我们经常需要使用pip安装第三方库来实现项目需求。然而,在执行pip install命令时,有时会遇到“pip不是内部或外部命令”的错误提示,…

    编程 2025-04-25
  • 动手学深度学习 PyTorch

    一、基本介绍 深度学习是对人工神经网络的发展与应用。在人工神经网络中,神经元通过接受输入来生成输出。深度学习通常使用很多层神经元来构建模型,这样可以处理更加复杂的问题。PyTorc…

    编程 2025-04-25
  • 深度解析Ant Design中Table组件的使用

    一、Antd表格兼容 Antd是一个基于React的UI框架,Table组件是其重要的组成部分之一。该组件可在各种浏览器和设备上进行良好的兼容。同时,它还提供了多个版本的Antd框…

    编程 2025-04-25
  • 深度解析MySQL查看当前时间的用法

    MySQL是目前最流行的关系型数据库管理系统之一,其提供了多种方法用于查看当前时间。在本篇文章中,我们将从多个方面来介绍MySQL查看当前时间的用法。 一、当前时间的获取方法 My…

    编程 2025-04-24
  • 深度学习鱼书的多个方面详解

    一、基础知识介绍 深度学习鱼书是一本系统性的介绍深度学习的图书,主要介绍深度学习的基础知识和数学原理,并且通过相关的应用案例来帮助读者理解深度学习的应用场景和方法。在了解深度学习之…

    编程 2025-04-24

发表回复

登录后才能评论