小批量梯度下降法的详细阐述

一、什么是小批量梯度下降法

1、小批量梯度下降法(Mini-batch Gradient Descent, MBGD)是一种介于梯度下降法(GD)和随机梯度下降法(SGD)之间的优化算法。在每次迭代时,它不像GD一样使用所有的训练样本,也不像SGD一样只使用一个样本,而是使用一小部分训练样本(通常是2到1000个)。

2、这样可以在降低随机性和提高算法效率之间取得平衡。与SGD相比,MBGD在每次迭代时使用更多的数据,因此更可能找到全局最优解;与GD相比 ,MBGD在每次迭代时使用更少的数据,因此更快、更节省内存。

二、小批量梯度下降法的实现

1、首先需要定义一个损失函数(Loss Function),通常是均方误差函数,表示预测结果与真实结果之间的差距。

def loss_function(y_true, y_pred):
    return ((y_true-y_pred)**2).mean()

2、然后需要定义模型(Model),通常是一个线性回归模型,也可以是其他类型的模型。

class LinearRegression:
    def __init__(self, n_features):
        self.n_features = n_features
        self.weights = np.random.randn(n_features)
        self.bias = np.random.randn()

    def forward(self, x):
        return np.dot(x, self.weights) + self.bias

    def backward(self, x, y, y_pred):
        n_samples = x.shape[0]
        d_weights = (2 / n_samples) * np.dot(x.T, (y_pred - y))
        d_bias = (2 / n_samples) * np.sum(y_pred - y)
        return d_weights, d_bias

3、在训练过程中,需要随机抽取一小部分样本构成一个batch,计算这个batch的损失和梯度,然后更新模型参数。

def train_step(model, optimizer, x_batch, y_batch):
    # forward
    y_pred = model.forward(x_batch)
    # backward
    d_weights, d_bias = model.backward(x_batch, y_batch, y_pred)
    # update
    optimizer.update(model, d_weights, d_bias)
    # compute loss
    loss = loss_function(y_batch, y_pred)
    return loss

三、小批量梯度下降法的优点和缺点

1、优点:

(1)相对于梯度下降法,小批量梯度下降法更快,内存消耗更少,更适合大规模数据集的训练;

(2)相对于随机梯度下降法,小批量梯度下降法更稳定,更容易找到全局最优解;

(3)由于小批量梯度下降法使用了一部分数据,因此可以获得比随机梯度下降法更准确的梯度,从而更快地收敛。

2、缺点:

(1)需要调整batch size的大小,太小容易增加噪声,太大会占用过多的内存;

(2)需要调整学习率(learning rate)的大小,太小可能导致收敛过慢,太大可能导致震荡不收敛;

(3)需要对数据进行shuffle,否则容易陷入局部最优解。

四、小批量梯度下降法的应用

1、小批量梯度下降法是深度学习中最常用的优化算法之一,广泛应用于神经网络的训练;

2、小批量梯度下降法也可以应用于其他机器学习领域,如线性回归、逻辑回归、支持向量机等;

3、小批量梯度下降法的变种还有动量梯度下降法、Adam等,它们在小批量梯度下降法的基础上加入了一些优化技巧,可以获得更好的性能。

原创文章,作者:SSNSG,如若转载,请注明出处:https://www.506064.com/n/349483.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
SSNSGSSNSG
上一篇 2025-02-15 17:10
下一篇 2025-02-15 17:10

相关推荐

  • 预处理共轭梯度法

    预处理共轭梯度法是一种求解线性方程组的迭代方法,相比直接求解,其具有更高的效率和更快的速度。本文将从几个方面对预处理共轭梯度法进行详细的阐述,并给出完整的代码示例。 一、预处理共轭…

    编程 2025-04-28
  • Python逻辑回归梯度下降法

    本文将通过Python逻辑回归梯度下降法,对于逻辑回归的原理、实现方法和应用进行详细阐述。 一、逻辑回归原理 逻辑回归是一种常用的分类算法,其原理可以用线性回归模型来描述,将线性回…

    编程 2025-04-27
  • index.html怎么打开 – 详细解析

    一、index.html怎么打开看 1、如果你已经拥有了index.html文件,那么你可以直接使用任何一个现代浏览器打开index.html文件,比如Google Chrome、…

    编程 2025-04-25
  • Resetful API的详细阐述

    一、Resetful API简介 Resetful(REpresentational State Transfer)是一种基于HTTP协议的Web API设计风格,它是一种轻量级的…

    编程 2025-04-25
  • 关键路径的详细阐述

    关键路径是项目管理中非常重要的一个概念,它通常指的是项目中最长的一条路径,它决定了整个项目的完成时间。在这篇文章中,我们将从多个方面对关键路径做详细的阐述。 一、概念 关键路径是指…

    编程 2025-04-25
  • AXI DMA的详细阐述

    一、AXI DMA概述 AXI DMA是指Advanced eXtensible Interface Direct Memory Access,是Xilinx公司提供的基于AMBA…

    编程 2025-04-25
  • neo4j菜鸟教程详细阐述

    一、neo4j介绍 neo4j是一种图形数据库,以实现高效的图操作为设计目标。neo4j使用图形模型来存储数据,数据的表述方式类似于实际世界中的网络。neo4j具有高效的读和写操作…

    编程 2025-04-25
  • c++ explicit的详细阐述

    一、explicit的作用 在C++中,explicit关键字可以在构造函数声明前加上,防止编译器进行自动类型转换,强制要求调用者必须强制类型转换才能调用该函数,避免了将一个参数类…

    编程 2025-04-25
  • HTMLButton属性及其详细阐述

    一、button属性介绍 button属性是HTML5新增的属性,表示指定文本框拥有可供点击的按钮。该属性包括以下几个取值: 按钮文本 提交 重置 其中,type属性表示按钮类型,…

    编程 2025-04-25
  • Vim使用教程详细指南

    一、Vim使用教程 Vim是一个高度可定制的文本编辑器,可以在Linux,Mac和Windows等不同的平台上运行。它具有快速移动,复制,粘贴,查找和替换等强大功能,尤其在面对大型…

    编程 2025-04-25

发表回复

登录后才能评论