PyTorch参数初始化详解

一、基础知识

参数初始化是深度学习模型中的重要环节之一,它直接影响到模型的泛化能力和训练效果。在 PyTorch 中给参数赋初值有两种方式,分别是手动设置和自动初始化。在使用手动设置时需要注意参数的大小、含义和初始化方式。同时,PyTorch 提供了一些默认的初始化方式,可以方便地使用。

在 PyTorch 中,模型的参数是存储在 Parameter 类型的变量中,其初始化方式主要包括:

  • 常量初始化
  • 随机初始化
  • 预训练模型初始化

二、常量初始化

常量初始化是最简单的初始化方式,它将参数赋为固定的常量值。这种方式通常不是很常用,但有时候可以用于解决特殊的问题。例如,当我们想固定某些参数的值不变时,可以使用常量初始化。

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128*8*8, 1024)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1024, 10)

        # 将conv1的权重初始化为0.1
        nn.init.constant_(self.conv1.weight, 0.1)

三、随机初始化

随机初始化是最常用的初始化方式之一,它可以使得参数在一定范围内发生变化,增强模型的泛化能力。随机初始化通常包括以下几种方式:

  • 均匀分布初始化
  • 正态分布初始化
  • 截断正态分布初始化
  • 自定义初始化

其中,均匀分布初始化和正态分布初始化在 PyTorch 中均有对应的函数,可以直接使用。

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128*8*8, 1024)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1024, 10)

        # 均匀分布初始化
        nn.init.uniform_(self.fc1.weight, -0.1, 0.1)

        # 正态分布初始化
        nn.init.normal_(self.fc2.weight, mean=0, std=0.01)

四、预训练模型初始化

预训练模型初始化是指使用已经预先训练好的模型来初始化当前模型的参数,该方式在迁移学习中应用广泛。在 PyTorch 中,使用预训练模型进行初始化通常有两种方式,分别是从文件中加载和在线下载。

import torchvision.models as models

# 从文件中加载预训练模型
resnet18 = models.resnet18(pretrained=True)

# 在线下载预训练模型,需要联网
vgg16 = models.vgg16(pretrained=True)

五、自定义初始化

有时候,我们需要使用一些特殊的方式来初始化模型参数,这时就需要自定义初始化函数。在 PyTorch 中,可以使用自定义初始化函数来实现这一点。

import torch.nn as nn

def my_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0.1)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(128*8*8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # 自定义初始化
        self.apply(my_init)

六、小结

在 PyTorch 中,参数初始化是深度学习模型中不可或缺的重要部分。了解各种初始化方式的优缺点,根据不同的网络结构和需求选择合适的初始化方式,是提高模型性能的重要手段。

原创文章,作者:小蓝,如若转载,请注明出处:https://www.506064.com/n/160606.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
小蓝小蓝
上一篇 2024-11-21 01:17
下一篇 2024-11-21 01:17

相关推荐

  • 三星内存条参数用法介绍

    本文将详细解释三星内存条上面的各种参数,让你更好地了解内存条并选择适合自己的一款。 一、容量大小 容量大小是内存条最基本的参数,一般以GB为单位表示,常见的有2GB、4GB、8GB…

    编程 2025-04-29
  • Python3定义函数参数类型

    Python是一门动态类型语言,不需要在定义变量时显示的指定变量类型,但是Python3中提供了函数参数类型的声明功能,在函数定义时明确定义参数类型。在函数的形参后面加上冒号(:)…

    编程 2025-04-29
  • Python input参数变量用法介绍

    本文将从多个方面对Python input括号里参数变量进行阐述与详解,并提供相应的代码示例。 一、基本介绍 Python input()函数用于获取用户输入。当程序运行到inpu…

    编程 2025-04-29
  • Spring Boot中发GET请求参数的处理

    本文将详细介绍如何在Spring Boot中处理GET请求参数,并给出完整的代码示例。 一、Spring Boot的GET请求参数基础 在Spring Boot中,处理GET请求参…

    编程 2025-04-29
  • Python函数名称相同参数不同:多态

    Python是一门面向对象的编程语言,它强烈支持多态性 一、什么是多态多态是面向对象三大特性中的一种,它指的是:相同的函数名称可以有不同的实现方式。也就是说,不同的对象调用同名方法…

    编程 2025-04-29
  • Python Class括号中的参数用法介绍

    本文将对Python中类的括号中的参数进行详细解析,以帮助初学者熟悉和掌握类的创建以及参数设置。 一、Class的基本定义 在Python中,通过使用关键字class来定义类。类包…

    编程 2025-04-29
  • Hibernate日志打印sql参数

    本文将从多个方面介绍如何在Hibernate中打印SQL参数。Hibernate作为一种ORM框架,可以通过打印SQL参数方便开发者调试和优化Hibernate应用。 一、通过配置…

    编程 2025-04-29
  • 全能编程开发工程师必知——DTD、XML、XSD以及DTD参数实体

    本文将从大体介绍DTD、XML以及XSD三大知识点,同时深入探究DTD参数实体的作用及实际应用场景。 一、DTD介绍 DTD是文档类型定义(Document Type Defini…

    编程 2025-04-29
  • Python可变参数

    本文旨在对Python中可变参数进行详细的探究和讲解,包括可变参数的概念、实现方式、使用场景等多个方面,希望能够对Python开发者有所帮助。 一、可变参数的概念 可变参数是指函数…

    编程 2025-04-29
  • XGBoost n_estimator参数调节

    XGBoost 是 处理结构化数据常用的机器学习框架之一,其中的 n_estimator 参数决定着模型的复杂度和训练速度,这篇文章将从多个方面详细阐述 n_estimator 参…

    编程 2025-04-28

发表回复

登录后才能评论