Files
python/Pytorch/nn/viFunc/nn.Linear.md
2025-09-09 15:56:55 +08:00

4.7 KiB
Raw Permalink Blame History

torch.nn.Linear 类的方法详解

根据 PyTorch 官方文档torch.nn.Linear这是一个实现线性变换全连接层的模块数学公式为 ( y = xA^T + b ),其中 ( A ) 是权重矩阵,( b ) 是偏置(如果启用)。该类继承自 nn.Module,因此除了特定方法外,还支持 nn.Module 的通用方法(如 train()eval() 等),但这里重点讲解文档中明确提到的所有方法,包括 __init__forwardextra_repr。文档中未提及其他特定方法。

以下是每个方法的详细讲解,包括方法签名、参数、返回值和解释。参数类型基于 Python 标准类型和 PyTorch 张量规范。

1. __init__ 方法(初始化方法)

  • 签名torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
  • 参数
    • in_featuresint每个输入样本的大小输入特征维度
    • out_featuresint每个输出样本的大小输出特征维度
    • biasbool默认 True如果为 False,则不学习可加偏置项。
    • devicetorch.device 或 None默认 None指定模块的设备如 CPU 或 GPU
    • dtypetorch.dtype 或 None默认 None指定模块的数据类型。
  • 返回值:无(这是构造函数,用于初始化模块)。
  • 解释:该方法初始化线性变换模块。它会创建权重矩阵 ( A )(形状为 (out_features, in_features))和可选的偏置 ( b )(形状为 (out_features,))。权重和偏置使用均匀分布 (\mathcal{U}(-\sqrt{k}, \sqrt{k})) 初始化,其中 ( k = \frac{1}{\text{in_features}} )。这是一个核心方法,用于构建层,例如 nn.Linear(20, 30) 创建一个从 20 维输入到 30 维输出的层。
  • 注意:模块支持 TensorFloat32在支持的硬件上加速并且在某些 ROCm 设备上float16 输入会转换为 float32 进行计算。

2. forward 方法(前向传播方法)

  • 签名forward(input)(输入参数名为 input,但文档中未显式列出签名,通常推断自描述)。
  • 参数
    • inputtorch.Tensor输入张量形状为 (*, H_in),其中 * 表示任意数量的维度(包括无维度),H_in = in_features(最后一个维度必须匹配 in_features)。
  • 返回值torch.Tensor形状为 (*, H_out),其中前面的维度与输入相同,H_out = out_features
  • 解释:这是模块的核心前向传播方法。它对输入应用线性变换 ( y = xA^T + b )(如果 bias=True)。权重 ( A ) 和偏置 ( b ) 是可学习的参数,通过反向传播更新。该方法是 PyTorch 模块的标准接口,在模型构建中调用如 output = layer(input) 时隐式执行。
  • 示例(来自官方文档):
    import torch.nn as nn
    m = nn.Linear(20, 30)  # 初始化层
    input = torch.randn(128, 20)  # 批量大小 128特征 20
    output = m(input)  # 前向传播
    print(output.size())  # 输出: torch.Size([128, 30])
    
    这展示了批量输入如何转换为输出形状。

3. extra_repr 方法(额外字符串表示方法)

  • 签名extra_repr()(无参数,继承自 nn.Module)。
  • 参数:无。
  • 返回值str一个字符串总结模块的参数。
  • 解释:该方法返回一个字符串,用于增强模块的打印表示(例如,当你 print(layer) 时)。它通常包括 in_featuresout_features 和偏置是否启用(如 "Linear(in_features=20, out_features=30, bias=True)")。这是一个辅助方法,帮助调试和可视化模型结构,不是计算核心。
  • 注意:这是 nn.Module 的标准方法,在 Linear 中被覆盖以包含特定参数。没有额外示例,但打印模块时会自动调用。

附加说明

  • 继承方法torch.nn.Linear 继承自 nn.Module,因此支持所有 nn.Module 方法,如 parameters()(获取可训练参数)、named_parameters()(命名参数迭代)、zero_grad()(梯度清零)等。这些不是 Linear 特有,而是通用方法。如果需要讲解整个 nn.Module,可以进一步查询。
  • 属性:模块有 weight(权重张量)和 bias(偏置张量,如果启用),这些是 nn.Parameter 类型,可通过优化器更新。
  • 使用提示:在构建神经网络时,Linear 常用于 MLP 等。确保输入的最后一个维度匹配 in_features,否则会报错。文档强调该模块在 float16 输入时的特殊行为(在 ROCm 上转换为 float32以避免精度问题。

如果需要代码示例、数学推导或特定版本的差异,请提供更多细节!