4.7 KiB
4.7 KiB
torch.nn.Linear 类的方法详解
根据 PyTorch 官方文档(torch.nn.Linear),这是一个实现线性变换(全连接层)的模块,数学公式为 ( y = xA^T + b ),其中 ( A ) 是权重矩阵,( b ) 是偏置(如果启用)。该类继承自 nn.Module
,因此除了特定方法外,还支持 nn.Module
的通用方法(如 train()
、eval()
等),但这里重点讲解文档中明确提到的所有方法,包括 __init__
、forward
和 extra_repr
。文档中未提及其他特定方法。
以下是每个方法的详细讲解,包括方法签名、参数、返回值和解释。参数类型基于 Python 标准类型和 PyTorch 张量规范。
1. __init__
方法(初始化方法)
- 签名:
torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
- 参数:
in_features
(int):每个输入样本的大小(输入特征维度)。out_features
(int):每个输出样本的大小(输出特征维度)。bias
(bool,默认 True):如果为False
,则不学习可加偏置项。device
(torch.device 或 None,默认 None):指定模块的设备(如 CPU 或 GPU)。dtype
(torch.dtype 或 None,默认 None):指定模块的数据类型。
- 返回值:无(这是构造函数,用于初始化模块)。
- 解释:该方法初始化线性变换模块。它会创建权重矩阵 ( A )(形状为
(out_features, in_features)
)和可选的偏置 ( b )(形状为(out_features,)
)。权重和偏置使用均匀分布 (\mathcal{U}(-\sqrt{k}, \sqrt{k})) 初始化,其中 ( k = \frac{1}{\text{in_features}} )。这是一个核心方法,用于构建层,例如nn.Linear(20, 30)
创建一个从 20 维输入到 30 维输出的层。 - 注意:模块支持 TensorFloat32(在支持的硬件上加速),并且在某些 ROCm 设备上,float16 输入会转换为 float32 进行计算。
2. forward
方法(前向传播方法)
- 签名:
forward(input)
(输入参数名为input
,但文档中未显式列出签名,通常推断自描述)。 - 参数:
input
(torch.Tensor):输入张量,形状为(*, H_in)
,其中*
表示任意数量的维度(包括无维度),H_in = in_features
(最后一个维度必须匹配in_features
)。
- 返回值:torch.Tensor,形状为
(*, H_out)
,其中前面的维度与输入相同,H_out = out_features
。 - 解释:这是模块的核心前向传播方法。它对输入应用线性变换 ( y = xA^T + b )(如果
bias=True
)。权重 ( A ) 和偏置 ( b ) 是可学习的参数,通过反向传播更新。该方法是 PyTorch 模块的标准接口,在模型构建中调用如output = layer(input)
时隐式执行。 - 示例(来自官方文档):
这展示了批量输入如何转换为输出形状。
import torch.nn as nn m = nn.Linear(20, 30) # 初始化层 input = torch.randn(128, 20) # 批量大小 128,特征 20 output = m(input) # 前向传播 print(output.size()) # 输出: torch.Size([128, 30])
3. extra_repr
方法(额外字符串表示方法)
- 签名:
extra_repr()
(无参数,继承自nn.Module
)。 - 参数:无。
- 返回值:str,一个字符串,总结模块的参数。
- 解释:该方法返回一个字符串,用于增强模块的打印表示(例如,当你
print(layer)
时)。它通常包括in_features
、out_features
和偏置是否启用(如 "Linear(in_features=20, out_features=30, bias=True)")。这是一个辅助方法,帮助调试和可视化模型结构,不是计算核心。 - 注意:这是
nn.Module
的标准方法,在Linear
中被覆盖以包含特定参数。没有额外示例,但打印模块时会自动调用。
附加说明
- 继承方法:
torch.nn.Linear
继承自nn.Module
,因此支持所有nn.Module
方法,如parameters()
(获取可训练参数)、named_parameters()
(命名参数迭代)、zero_grad()
(梯度清零)等。这些不是Linear
特有,而是通用方法。如果需要讲解整个nn.Module
,可以进一步查询。 - 属性:模块有
weight
(权重张量)和bias
(偏置张量,如果启用),这些是nn.Parameter
类型,可通过优化器更新。 - 使用提示:在构建神经网络时,
Linear
常用于 MLP 等。确保输入的最后一个维度匹配in_features
,否则会报错。文档强调该模块在 float16 输入时的特殊行为(在 ROCm 上转换为 float32)以避免精度问题。
如果需要代码示例、数学推导或特定版本的差异,请提供更多细节!