PyTorch固定随机种子的多个方面详解

一、为什么要固定随机种子

在深度学习中,模型的性能经常收到随机性的影响,如初始化、Dropout, BN等。这些随机性导致相同的训练过程,在不同的运行中可能会得到不同的结果或者训练过程会发生不同的情况。这对于实验的可重复性和比较来说是不可取的。为此,我们需要建立一般的方法来固定或控制这些随机化元素,以确保实验的可重复性。

二、PyTorch中随机种子固定的方法

在PyTorch中,我们可以使用以下代码来允许固定随机种子:

import torch
import random
import numpy as np

# Set the random seed manually for reproducibility.
torch.manual_seed(0)
# Set the random seed manually for reproducibility.
np.random.seed(0)
# Set the random seed manually for reproducibility.
random.seed(0)

在上述代码示例中,我们使用PyTorch、NumPy和Python内置随机函数库的函数分别调用了函数torch.manual_seed(0), np.random.seed(0)和random.seed(0)来设置随机种子。

三、 PyTorch中随机种子固定的方法详解

1、torch.backends.cudnn.benchmark

在使用GPU进行训练时,使用cudnn.benchmark可以自动寻求最优的cudnn算法来优化训练速度,但同时也会引入随机因素。因此,在使用可重复的实验时,我们要将其关闭。示例:

import torch.backends.cudnn as cudnn
# Turn off benchmark mode when not needed
cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

2、针对卷积器权重和偏执设置随机种子

在PyTorch中,通过nn.Conv2d类建立的卷积层,如果不显式的设置卷积层的偏置以及权重,PyTorch会自动生成。生成方式参考其源码,它会从均匀分布U(-stdv, stdv)中随机取值,其中stdv为sqrt(1 / n),其中n等于权重的元素个数或者输入通道数。对于这个变量的初始值的设置,会对模型的训练有很大的影响,这也是一个重要的随机因素。我们可以通过如下代码来设置随机种子:

torch.manual_seed(123)
model = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
                      nn.ReLU(),
                      nn.MaxPool2d(kernel_size=2, stride=2),
                      nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
                      nn.ReLU(),
                      nn.MaxPool2d(kernel_size=2, stride=2))

# Explicitly define the initialization of the conv parameters
for p in model.parameters():
    if len(p.shape) >=2:
        torch.nn.init.xavier_normal_(p)

3、随机数据生成器的随机种子

在模型训练的过程中,我们需要考虑输入的数据集是否容易受到随机性的影响,如果是,则需要初始化其生成方法的随机种子。示例代码:

# Configure the data generator to have a stable randomization pattern
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.ImageFolder(train_dir, transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.RandomResizedCrop(size),
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True)

4、 Dropout和BN层的随机种子

PyTorch中的Dropout和Batch Normalization层,同样会影响神经网络的性能。对于Dropout层的构建,可以手动设置随机种子:

class Model(nn.Module):

    def __init__(self):
        super(Model, self).__init__()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        x = self.dropout(x)
        return x

model = Model()
model.train()

# Explicitly define dropout's random seed
torch.manual_seed(123)
dropout_output1 = model(torch.randn([1, 3]))

# If we set the random seed again, we should get the same result.
torch.manual_seed(123)
dropout_output2 = model(torch.randn([1, 3]))

print(torch.all(torch.eq(dropout_output1, dropout_output2)))

对于Batch Normalization层,由于其属于一个自适应的过程,其无法通过简单的固定随机种子的方法固定其均值和方差。但是我们可以修改内部的参考数据集,使之与训练之前的数据保持一致,来实现模型的可复现性。代码示例:

class MyModel(nn.Module):

    def __init__(self):
        super(MyModel, self).__init__()
        self.bn = nn.BatchNorm2d(5)

    def forward(self, x):
        x = self.bn(x)
        return x

model = MyModel()

# In order to make the batch_norm layer deterministic,
# we can manually set the running mean/std to a known value.

# Assuming you have inputs of shape (batch_size, channel, height, width)
inputs = torch.randn(32, 5, 24, 24)

# Load the BN layer with data with fixed mean/std.
model.eval()
out = model(inputs)
model.train()
print(out.mean(), out.var())

四、 总结

在PyTorch中,固定随机种子是非常关键的方法之一来确保实验的可重复性。在实际的开发过程中,我们需要考虑到多个方面的随机化元素,如Dropout、BatchNormalization层等等。只有统一设置好随机种子,才能保证各种实验之间的可比性和生产环境的稳定性。

原创文章,作者:BDDCH,如若转载,请注明出处:https://www.506064.com/n/332702.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
BDDCHBDDCH
上一篇 2025-01-24 18:47
下一篇 2025-01-27 13:34

相关推荐

  • 为什么Python不能编译?——从多个方面浅析原因和解决方法

    Python作为很多开发人员、数据科学家和计算机学习者的首选编程语言之一,受到了广泛关注和应用。但与之伴随的问题之一是Python不能编译,这给基于编译的开发和部署方式带来不少麻烦…

    编程 2025-04-29
  • Java判断字符串是否存在多个

    本文将从以下几个方面详细阐述如何使用Java判断一个字符串中是否存在多个指定字符: 一、字符串遍历 字符串是Java编程中非常重要的一种数据类型。要判断字符串中是否存在多个指定字符…

    编程 2025-04-29
  • Python合并多个相同表头文件

    对于需要合并多个相同表头文件的情况,我们可以使用Python来实现快速的合并。 一、读取CSV文件 使用Python中的csv库读取CSV文件。 import csv with o…

    编程 2025-04-29
  • 从多个方面用法介绍yes,but let me review and configure level of access

    yes,but let me review and configure level of access是指在授权过程中,需要进行确认和配置级别控制的全能编程开发工程师。 一、授权确…

    编程 2025-04-29
  • 从多个方面zmjui

    zmjui是一个轻量级的前端UI框架,它实现了丰富的UI组件和实用的JS插件,让前端开发更加快速和高效。本文将从多个方面对zmjui做详细阐述,帮助读者深入了解zmjui,以便更好…

    编程 2025-04-28
  • 学Python用什么编辑器?——从多个方面评估各种Python编辑器

    选择一个适合自己的 Python 编辑器并不容易。除了我们开发的应用程序类型、我们面临的软件架构以及我们的编码技能之外,选择编辑器可能也是我们编写代码时最重要的决定之一。随着许多不…

    编程 2025-04-28
  • 使用easypoi创建多个动态表头

    本文将详细介绍如何使用easypoi创建多个动态表头,让表格更加灵活和具有可读性。 一、创建单个动态表头 easypoi是一个基于POI操作Excel的Java框架,支持通过注解的…

    编程 2025-04-28
  • 创建列表的多个方面

    本文将从多个方面对创建列表进行详细阐述。 一、列表基本概念 列表是一种数据结构,其中元素以线性方式组织,并且具有特殊的序列位置。该位置可以通过索引或一些其他方式进行访问。在编程中,…

    编程 2025-04-28
  • Python多个sheet表合并用法介绍

    本文将从多个方面对Python多个sheet表合并进行详细的阐述。 一、xlrd与xlwt模块的基础知识 xlrd与xlwt是Python中处理Excel文件的重要模块。xlrd模…

    编程 2025-04-27
  • 从多个角度用法介绍lower down

    lower down是一个常用于编程开发中的操作。它可以对某个值或变量进行降低精度的处理,非常适合于一些需要精度不高但速度快的场景。那么,在本文中,我们将从多个角度解析lower …

    编程 2025-04-27

发表回复

登录后才能评论