PyTorch優化庫torch.optim詳解

PyTorch是深度學習界最為火熱的框架之一,而torch.optim作為PyTorch中的優化庫,其不僅為深度學習模型的訓練提供了高效、快捷的方式,同時也為各種優化算法的實現提供了標準化的接口。本文將從torch.optim的基本使用出發,逐一解析SGD、RMSprop、Adam、LBFGS、AdamW、Adagrad、SWA等優化算法,並深入剖析torch.optim中的參數設定和優化過程。

一、torch.optim.SGD

隨機梯度下降法(Stochastic Gradient Descent, SGD)是一種最基本、最經典、也是最廣泛使用的優化算法。在torch.optim中,SGD的默認learning rate為0.1,但在實際使用中,不同的模型和數據常常需要選取不同的learning rate才能達到最好的效果。


    import torch.optim as optim
    
    # 定義模型和損失函數
    model = Net()
    criterion = nn.CrossEntropyLoss()
    
    # 定義優化器
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    
    # 訓練模型
    for epoch in range(num_epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            
            # 前向傳播
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # 反向傳播
            loss.backward()
            optimizer.step()

在使用SGD時,我們可以通過實驗選取最佳的learning rate。如果learning rate太小,模型訓練速度會變緩,並有可能陷入局部最優解;如果learning rate太大,模型訓練速度會顯著提升,但會導致模型過早散開,無法收斂到最優結果。

二、torch.optim.SGD參數

除了learning rate之外,SGD還有一些其他的參數可以調整。其中,“momentum”是SGD的一個重要參數,它可以讓模型沿着之前一定的方向繼續前進,從而避免陷入局部最優解。這裡以momentum=0.9為例進行演示。


    # 定義優化器
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

同時,可以通過設置weight_decay來控制網絡的正則化程度,從而儘可能避免過擬合。


    # 定義優化器
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)

三、torch.optim.rmsprop

除了SGD之外,RMSprop也是一種經典的優化算法。在RMSprop中,學習率會因為歷史梯度的大小而逐漸減小。


    # 定義優化器
    optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99)

其中alpha表示梯度的權重,當alpha越小時,歷史梯度的影響就越小。如下例子演示了如何實現動態地改變learning rate。


    # 定義優化器
    optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha=0.99)
    
    # 動態地更新learning rate
    scheduler = StepLR(optimizer, step_size=1, gamma=0.1)

四、torch.optim.adam

Adam是一種自適應學習率的優化算法,它能夠適應不同維度的梯度,並調整學習率。Adam也是當前最為流行的優化算法之一。下面是Adam在PyTorch中的實現方法。


    # 定義優化器
    optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))

其中betas分別指的是一階矩估計的指數衰減率和二階矩估計的指數衰減率。

五、torch.optim.LBFGS

L-BFGS是一種擬牛頓法,它在模型參數空間中近似Hessian矩陣,並對其進行更新。L-BFGS通常使用於小樣本、高維度的問題上。下面是如何在PyTorch中使用L-BFGS。


    # 定義優化器
    optimizer = optim.LBFGS(model.parameters(), lr=0.01, max_iter=20)
    
    # 定義訓練函數
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = criterion(output, target)
        loss.backward()
        return loss
        
    # 訓練模型
    for epoch in range(num_epochs):
        optimizer.step(closure)

在使用L-BFGS時,我們可以通過設置max_iter來控制迭代次數。同時,由於L-BFGS只能處理一小批數據,因此在每個迭代步驟中都需要先清空優化器。

六、torch.optim.adamw

AdamW與Adam非常相似,但AdamW加入了權重衰減(Weight Decay)。權重衰減能夠限制W的大小,避免過擬合。下面是如何在PyTorch中使用AdamW。


    # 定義優化器
    optimizer = optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=0.01)

七、torch.optim.Adagrad

Adagrad是一種累計梯度優化算法,可以自適應地調整學習率,更快更好地訓練模型。下面是如何在PyTorch中使用Adagrad的方法。


    # 定義優化器
    optimizer = optim.Adagrad(model.parameters(), lr=0.01, weight_decay=0.01)

八、torch.optim.adam參數

Adam優化器中還有一些參數需要設置,下面將逐一介紹。

1)eps:用於穩定模型求解,避免出現除以0的情況。


    # 定義優化器
    optimizer = optim.Adam(model.parameters(), lr=0.001, eps=1e-08)

2)amsgrad:是否使用AMSGrad方法來保證梯度的平穩性。


    # 定義優化器
    optimizer = optim.Adam(model.parameters(), lr=0.001, amsgrad=True)

九、torch.optim.swa_utils

SWA是一種基於SGD的優化算法,它通過計算所有epoch中的模型的均值,並將其作為最終模型。SWA具有快速收斂和較好的泛化能力,因此在深度學習領域中非常受歡迎。

在PyTorch中使用SWA需要進行如下操作:


    # 導入swa_utils
    from torch.optim.swa_utils import AveragedModel, SWALR
    
    # 定義優化器
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    
    # 運用SWA策略
    swa_model = AveragedModel(model)
    swa_start = 10
    swa_scheduler = SWALR(optimizer, swa_lr=0.05)
    
    for epoch in range(num_epochs):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            
            # 前向傳播
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # 反向傳播
            loss.backward()
            optimizer.step()
            
            # SWA模型更新
            if epoch > swa_start:
                swa_model.update_parameters(model)
                swa_scheduler.step()
                
        # 保存模型
        if epoch > swa_start:
            swa_model = swa_model.to('cpu')
            torch.save({
                'epoch': epoch,
                'model_state_dict': swa_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
            }, './checkpoints/model_{}.pth'.format(epoch))
            swa_model.cuda()

其中,AveragedModel用於計算所有epoch中的模型的均值,通過SWALR策略來動態地更新learning rate。

十、torch.optim.Optimizer

在PyTorch中,所有的優化器都繼承自torch.optim.Optimizer類。通過該類,我們可以方便地實現各種優化算法。下面是Optimzier類的一個簡單示例。


    class MyOptimizer(torch.optim.Optimizer):
        def __init__(self, params, lr=0.1, momentum=0.9):
            defaults = dict(lr=lr, momentum=momentum)
            super(MyOptimizer, self).__init__(params, defaults)
            
        def __setstate__(self, state):
            super(MyOptimizer, self).__setstate__(state)
        
        def step(self, closure=None):
            loss = None
            if closure is not None:
                loss = closure()
                
            for group in self.param_groups:
                lr = group['lr']
                momentum = group['momentum']
                
                for p in group['params']:
                    if p.grad is None:
                        continue
                        
                    d_p = p.grad.data
                    if momentum != 0:
                        param_state = self.state[p]
                        if 'momentum_buffer' not in param_state:
                            buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                            buf.mul_(momentum).add_(d_p)
                            p.data.add_(-lr, buf)
                            
            return loss

通過自定義Optimzer,我們可以更加靈活地實現各種優化算法。

原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/304655.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
小藍的頭像小藍
上一篇 2025-01-01 11:05
下一篇 2025-01-01 11:05

相關推薦

  • PyTorch模塊簡介

    PyTorch是一個開源的機器學習框架,它基於Torch,是一個Python優先的深度學習框架,同時也支持C++,非常容易上手。PyTorch中的核心模塊是torch,提供一些很好…

    編程 2025-04-27
  • Linux sync詳解

    一、sync概述 sync是Linux中一個非常重要的命令,它可以將文件系統緩存中的內容,強制寫入磁盤中。在執行sync之前,所有的文件系統更新將不會立即寫入磁盤,而是先緩存在內存…

    編程 2025-04-25
  • 神經網絡代碼詳解

    神經網絡作為一種人工智能技術,被廣泛應用於語音識別、圖像識別、自然語言處理等領域。而神經網絡的模型編寫,離不開代碼。本文將從多個方面詳細闡述神經網絡模型編寫的代碼技術。 一、神經網…

    編程 2025-04-25
  • Python輸入輸出詳解

    一、文件讀寫 Python中文件的讀寫操作是必不可少的基本技能之一。讀寫文件分別使用open()函數中的’r’和’w’參數,讀取文件…

    編程 2025-04-25
  • C語言貪吃蛇詳解

    一、數據結構和算法 C語言貪吃蛇主要運用了以下數據結構和算法: 1. 鏈表 typedef struct body { int x; int y; struct body *nex…

    編程 2025-04-25
  • Java BigDecimal 精度詳解

    一、基礎概念 Java BigDecimal 是一個用於高精度計算的類。普通的 double 或 float 類型只能精確表示有限的數字,而對於需要高精度計算的場景,BigDeci…

    編程 2025-04-25
  • git config user.name的詳解

    一、為什麼要使用git config user.name? git是一個非常流行的分布式版本控制系統,很多程序員都會用到它。在使用git commit提交代碼時,需要記錄commi…

    編程 2025-04-25
  • Linux修改文件名命令詳解

    在Linux系統中,修改文件名是一個很常見的操作。Linux提供了多種方式來修改文件名,這篇文章將介紹Linux修改文件名的詳細操作。 一、mv命令 mv命令是Linux下的常用命…

    編程 2025-04-25
  • MPU6050工作原理詳解

    一、什麼是MPU6050 MPU6050是一種六軸慣性傳感器,能夠同時測量加速度和角速度。它由三個傳感器組成:一個三軸加速度計和一個三軸陀螺儀。這個組合提供了非常精細的姿態解算,其…

    編程 2025-04-25
  • Python安裝OS庫詳解

    一、OS簡介 OS庫是Python標準庫的一部分,它提供了跨平台的操作系統功能,使得Python可以進行文件操作、進程管理、環境變量讀取等系統級操作。 OS庫中包含了大量的文件和目…

    編程 2025-04-25

發表回復

登錄後才能評論