PyTorch优化库torch.optim详解

PyTorch是深度学习界最为火热的框架之一,而torch.optim作为PyTorch中的优化库,其不仅为深度学习模型的训练提供了高效、快捷的方式,同时也为各种优化算法的实现提供了标准化的接口。本文将从torch.optim的基本使用出发,逐一解析SGD、RMSprop、Adam、LBFGS、AdamW、Adagrad、SWA等优化算法,并深入剖析torch.optim中的参数设定和优化过程。

一、torch.optim.SGD

随机梯度下降法(Stochastic Gradient Descent, SGD)是一种最基本、最经典、也是最广泛使用的优化算法。在torch.optim中,SGD的默认learning rate为0.1,但在实际使用中,不同的模型和数据常常需要选取不同的learning rate才能达到最好的效果。


    import torch.optim as optim
    
    # 定义模型和损失函数
    model = Net()
    criterion = nn.CrossEntropyLoss()
    
    # 定义优化器
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    
    # 训练模型
    for epoch in range(num_epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            
            # 前向传播
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # 反向传播
            loss.backward()
            optimizer.step()

在使用SGD时,我们可以通过实验选取最佳的learning rate。如果learning rate太小,模型训练速度会变缓,并有可能陷入局部最优解;如果learning rate太大,模型训练速度会显著提升,但会导致模型过早散开,无法收敛到最优结果。

二、torch.optim.SGD参数

除了learning rate之外,SGD还有一些其他的参数可以调整。其中,“momentum”是SGD的一个重要参数,它可以让模型沿着之前一定的方向继续前进,从而避免陷入局部最优解。这里以momentum=0.9为例进行演示。


    # 定义优化器
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

同时,可以通过设置weight_decay来控制网络的正则化程度,从而尽可能避免过拟合。


    # 定义优化器
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)

三、torch.optim.rmsprop

除了SGD之外,RMSprop也是一种经典的优化算法。在RMSprop中,学习率会因为历史梯度的大小而逐渐减小。


    # 定义优化器
    optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99)

其中alpha表示梯度的权重,当alpha越小时,历史梯度的影响就越小。如下例子演示了如何实现动态地改变learning rate。


    # 定义优化器
    optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99)
    
    # 动态地更新learning rate
    scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

四、torch.optim.adam

Adam是一种自适应学习率的优化算法,它能够适应不同维度的梯度,并调整学习率。Adam也是当前最为流行的优化算法之一。下面是Adam在PyTorch中的实现方法。


    # 定义优化器
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))

其中betas分别指的是一阶矩估计的指数衰减率和二阶矩估计的指数衰减率。

五、torch.optim.LBFGS

L-BFGS是一种拟牛顿法,它在模型参数空间中近似Hessian矩阵,并对其进行更新。L-BFGS通常使用于小样本、高维度的问题上。下面是如何在PyTorch中使用L-BFGS。


    # 定义优化器
    optimizer = optim.LBFGS(model.parameters(), lr=0.01, max_iter=20)
    
    # 定义训练函数
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        return loss
        
    # 训练模型
    for epoch in range(num_epochs):
        optimizer.step(closure)

在使用L-BFGS时,我们可以通过设置max_iter来控制迭代次数。同时,由于L-BFGS只能处理一小批数据,因此在每个迭代步骤中都需要先清空优化器。

六、torch.optim.adamw

AdamW与Adam非常相似,但AdamW加入了权重衰减(Weight Decay)。权重衰减能够限制W的大小,避免过拟合。下面是如何在PyTorch中使用AdamW。


    # 定义优化器
    optimizer = optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)

七、torch.optim.Adagrad

Adagrad是一种累计梯度优化算法,可以自适应地调整学习率,更快更好地训练模型。下面是如何在PyTorch中使用Adagrad的方法。


    # 定义优化器
    optimizer = optim.Adagrad(model.parameters(), lr=0.01, weight_decay=0.01)

八、torch.optim.adam参数

Adam优化器中还有一些参数需要设置,下面将逐一介绍。

1)eps:用于稳定模型求解,避免出现除以0的情况。


    # 定义优化器
    optimizer = optim.Adam(model.parameters(), lr=0.001, eps=1e-08)

2)amsgrad:是否使用AMSGrad方法来保证梯度的平稳性。


    # 定义优化器
    optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)

九、torch.optim.swa_utils

SWA是一种基于SGD的优化算法,它通过计算所有epoch中的模型的均值,并将其作为最终模型。SWA具有快速收敛和较好的泛化能力,因此在深度学习领域中非常受欢迎。

在PyTorch中使用SWA需要进行如下操作:


    # 导入swa_utils
    from torch.optim.swa_utils import AveragedModel, SWALR
    
    # 定义优化器
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    
    # 运用SWA策略
    swa_model = AveragedModel(model)
    swa_start = 10
    swa_scheduler = SWALR(optimizer, swa_lr=0.05)
    
    for epoch in range(num_epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            
            # 前向传播
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # 反向传播
            loss.backward()
            optimizer.step()
            
            # SWA模型更新
            if epoch > swa_start:
                swa_model.update_parameters(model)
                swa_scheduler.step()
                
        # 保存模型
        if epoch > swa_start:
            swa_model = swa_model.to('cpu')
            torch.save({
                'epoch': epoch,
                'model_state_dict': swa_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
            }, './checkpoints/model_{}.pth'.format(epoch))
            swa_model.cuda()

其中,AveragedModel用于计算所有epoch中的模型的均值,通过SWALR策略来动态地更新learning rate。

十、torch.optim.Optimizer

在PyTorch中,所有的优化器都继承自torch.optim.Optimizer类。通过该类,我们可以方便地实现各种优化算法。下面是Optimzier类的一个简单示例。


    class MyOptimizer(torch.optim.Optimizer):
        def __init__(self, params, lr=0.1, momentum=0.9):
            defaults = dict(lr=lr, momentum=momentum)
            super(MyOptimizer, self).__init__(params, defaults)
            
        def __setstate__(self, state):
            super(MyOptimizer, self).__setstate__(state)
        
        def step(self, closure=None):
            loss = None
            if closure is not None:
                loss = closure()
                
            for group in self.param_groups:
                lr = group['lr']
                momentum = group['momentum']
                
                for p in group['params']:
                    if p.grad is None:
                        continue
                        
                    d_p = p.grad.data
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                            buf.mul_(momentum).add_(d_p)
                            p.data.add_(-lr, buf)
                            
            return loss

通过自定义Optimzer,我们可以更加灵活地实现各种优化算法。

原创文章,作者:小蓝,如若转载,请注明出处:https://www.506064.com/n/304655.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
小蓝小蓝
上一篇 2025-01-01 11:05
下一篇 2025-01-01 11:05

相关推荐

  • PyTorch模块简介

    PyTorch是一个开源的机器学习框架,它基于Torch,是一个Python优先的深度学习框架,同时也支持C++,非常容易上手。PyTorch中的核心模块是torch,提供一些很好…

    编程 2025-04-27
  • Linux sync详解

    一、sync概述 sync是Linux中一个非常重要的命令,它可以将文件系统缓存中的内容,强制写入磁盘中。在执行sync之前,所有的文件系统更新将不会立即写入磁盘,而是先缓存在内存…

    编程 2025-04-25
  • 神经网络代码详解

    神经网络作为一种人工智能技术,被广泛应用于语音识别、图像识别、自然语言处理等领域。而神经网络的模型编写,离不开代码。本文将从多个方面详细阐述神经网络模型编写的代码技术。 一、神经网…

    编程 2025-04-25
  • Python输入输出详解

    一、文件读写 Python中文件的读写操作是必不可少的基本技能之一。读写文件分别使用open()函数中的’r’和’w’参数,读取文件…

    编程 2025-04-25
  • C语言贪吃蛇详解

    一、数据结构和算法 C语言贪吃蛇主要运用了以下数据结构和算法: 1. 链表 typedef struct body { int x; int y; struct body *nex…

    编程 2025-04-25
  • Java BigDecimal 精度详解

    一、基础概念 Java BigDecimal 是一个用于高精度计算的类。普通的 double 或 float 类型只能精确表示有限的数字,而对于需要高精度计算的场景,BigDeci…

    编程 2025-04-25
  • git config user.name的详解

    一、为什么要使用git config user.name? git是一个非常流行的分布式版本控制系统,很多程序员都会用到它。在使用git commit提交代码时,需要记录commi…

    编程 2025-04-25
  • Linux修改文件名命令详解

    在Linux系统中,修改文件名是一个很常见的操作。Linux提供了多种方式来修改文件名,这篇文章将介绍Linux修改文件名的详细操作。 一、mv命令 mv命令是Linux下的常用命…

    编程 2025-04-25
  • MPU6050工作原理详解

    一、什么是MPU6050 MPU6050是一种六轴惯性传感器,能够同时测量加速度和角速度。它由三个传感器组成:一个三轴加速度计和一个三轴陀螺仪。这个组合提供了非常精细的姿态解算,其…

    编程 2025-04-25
  • Python安装OS库详解

    一、OS简介 OS库是Python标准库的一部分,它提供了跨平台的操作系统功能,使得Python可以进行文件操作、进程管理、环境变量读取等系统级操作。 OS库中包含了大量的文件和目…

    编程 2025-04-25

发表回复

登录后才能评论