深入浅出Pytorch函数——torch.nn.init.calculate_gain

news/2025/2/23 5:49:19

分类目录:《深入浅出Pytorch函数》总目录


torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torc.no_grad()模式下运行,autograd不会将其考虑在内。

该函数对于给定的非线性函数,返回推荐的增益值。这些值如下所示:

NonlinearityGain
Linear / Identity 1 1 1
Conv1D / Conv2D / Conv3D 1 1 1
Sigmoid 1 1 1
Tanh 5 3 \frac{5}{3} 35
ReLU 2 \sqrt{2} 2
Leaky Relu 2 1 + negative_slope 2 \sqrt{\frac{2}{1+\text{negative\_slope}^2}} 1+negative_slope22
SELU 4 3 \frac{4}{3} 34

为了实现自归一化神经网络,应该使用nonlinearity='linear'而不是nonlinearity='selu'。这使得初始权重的方差为 1 N \frac{1}{N} N1,这对于在前向通道中引入稳定的固定点是必要的。相比之下,SELU的默认增益牺牲了矩形层中更稳定梯度流的归一化效应。

语法

torch.nn.init.calculate_gain(nonlinearity, param=None)

参数

  • nonlinearity:[nn.functional] 非线性函数名称
  • param:非线性函数的可选参数

实例

# leaky_relu with negative_slope=0.2
gain = nn.init.calculate_gain('leaky_relu', 0.2)  

函数实现

def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.
    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalisation
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    elif nonlinearity == 'selu':
        return 3.0 / 4  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

http://www.niftyadmin.cn/n/4951490.html

相关文章

如何使用PHP Smarty模板实现静态页面生成

首先,你需要从Smarty官网下载这个神奇的文件。然后,你需要在你的PHP文件中引入Smarty类。就像这样: require_once(Smarty.class.php);现在,我们要创建一个Smarty实例。这就像打开一个新的文件,只不过这个文件是可以和…

STM32 F103C8T6学习笔记8:0.96寸单色OLED显示屏显示字符

使用STM32F103 C8T6 驱动0.96寸单色OLED显示屏: OLED显示屏的驱动,在设计开发中OLED显示屏十分常见,因此今日学习一下。一篇文章从程序到显示都讲通。 文章提供源码、原理解释、测试工程下载,测试效果图展示。 目录 OLED驱动原理—IIC通信…

【Android】设置-显示-屏保-启用时机-默认选中“一律不“

设置-屏保-启用时机-默认选中"一律不" 解决步骤(1)理清思路(2)过程(3)效果图 解决步骤 (1)理清思路 操作步骤: 首先手机进入设置—》点进显示选项—》进入后…

神经网络基础-神经网络补充概念-59-padding

概念 在深度学习中,“padding”(填充)通常是指在卷积神经网络(Convolutional Neural Networks,CNNs)等神经网络层中,在输入数据的周围添加额外的元素(通常是零)&#xf…

spring头约束(全部)

文章目录 spring-mvcspring-aopspring-txspring-contextspring-taskspring-cachespring-jdbcp命令空间spring-jeejmslangoxmutil总结 spring-mvc <beans xmlns"http://www.springframework.org/schema/beans" xmlns:xsi"http://www.w3.org/2001/XMLSchema-…

SQL力扣练习(十一)

目录 1.树节点(608) 示例 1 解法一(case when) 解法二(not in) 2.判断三角形(610) 示例 1 解法一(case when) 解法二(if) 解法三(嵌套if) 3.只出现一次的最大数字(619) 示例 1 解法一(count limit) 解法二(max) 4.有趣的电影(620) 解法一 5.换座位(626) 示例 …

FifthOne:用于矢量搜索的计算机视觉接口

一、说明 数据太多了。数据湖和数据仓库;广阔的像素牧场和充满文字的海洋。找到正确的数据就像大海捞针一样&#xff01;如果你喜欢开源机器学习库 FiftyOne&#xff0c;矢量搜索引擎通过将复杂数据&#xff08;图像的原始像素值、文本文档中的字符&#xff09;转换为称为嵌入矢…

vue中自定义指令

在Vue中&#xff0c;可以通过自定义指令实现对DOM元素的自定义行为和交互。自定义指令可以用于添加事件监听、操作DOM、修改元素样式等。下面是一个简单的示例&#xff0c;展示如何在Vue中定义和使用自定义指令&#xff1a; 创建指令&#xff1a; 在Vue中&#xff0c;可以使用V…