### torch.nn.Linear 类的方法详解 根据 PyTorch 官方文档(torch.nn.Linear),这是一个实现线性变换(全连接层)的模块,数学公式为 \( y = xA^T + b \),其中 \( A \) 是权重矩阵,\( b \) 是偏置(如果启用)。该类继承自 `nn.Module`,因此除了特定方法外,还支持 `nn.Module` 的通用方法(如 `train()`、`eval()` 等),但这里重点讲解文档中明确提到的所有方法,包括 `__init__`、`forward` 和 `extra_repr`。文档中未提及其他特定方法。 以下是每个方法的详细讲解,包括方法签名、参数、返回值和解释。参数类型基于 Python 标准类型和 PyTorch 张量规范。 #### 1. `__init__` 方法(初始化方法) - **签名**:`torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)` - **参数**: - `in_features`(int):每个输入样本的大小(输入特征维度)。 - `out_features`(int):每个输出样本的大小(输出特征维度)。 - `bias`(bool,默认 True):如果为 `False`,则不学习可加偏置项。 - `device`(torch.device 或 None,默认 None):指定模块的设备(如 CPU 或 GPU)。 - `dtype`(torch.dtype 或 None,默认 None):指定模块的数据类型。 - **返回值**:无(这是构造函数,用于初始化模块)。 - **解释**:该方法初始化线性变换模块。它会创建权重矩阵 \( A \)(形状为 `(out_features, in_features)`)和可选的偏置 \( b \)(形状为 `(out_features,)`)。权重和偏置使用均匀分布 \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) 初始化,其中 \( k = \frac{1}{\text{in_features}} \)。这是一个核心方法,用于构建层,例如 `nn.Linear(20, 30)` 创建一个从 20 维输入到 30 维输出的层。 - **注意**:模块支持 TensorFloat32(在支持的硬件上加速),并且在某些 ROCm 设备上,float16 输入会转换为 float32 进行计算。 #### 2. `forward` 方法(前向传播方法) - **签名**:`forward(input)`(输入参数名为 `input`,但文档中未显式列出签名,通常推断自描述)。 - **参数**: - `input`(torch.Tensor):输入张量,形状为 `(*, H_in)`,其中 `*` 表示任意数量的维度(包括无维度),`H_in = in_features`(最后一个维度必须匹配 `in_features`)。 - **返回值**:torch.Tensor,形状为 `(*, H_out)`,其中前面的维度与输入相同,`H_out = out_features`。 - **解释**:这是模块的核心前向传播方法。它对输入应用线性变换 \( y = xA^T + b \)(如果 `bias=True`)。权重 \( A \) 和偏置 \( b \) 是可学习的参数,通过反向传播更新。该方法是 PyTorch 模块的标准接口,在模型构建中调用如 `output = layer(input)` 时隐式执行。 - **示例**(来自官方文档): ```python import torch.nn as nn m = nn.Linear(20, 30) # 初始化层 input = torch.randn(128, 20) # 批量大小 128,特征 20 output = m(input) # 前向传播 print(output.size()) # 输出: torch.Size([128, 30]) ``` 这展示了批量输入如何转换为输出形状。 #### 3. `extra_repr` 方法(额外字符串表示方法) - **签名**:`extra_repr()`(无参数,继承自 `nn.Module`)。 - **参数**:无。 - **返回值**:str,一个字符串,总结模块的参数。 - **解释**:该方法返回一个字符串,用于增强模块的打印表示(例如,当你 `print(layer)` 时)。它通常包括 `in_features`、`out_features` 和偏置是否启用(如 "Linear(in_features=20, out_features=30, bias=True)")。这是一个辅助方法,帮助调试和可视化模型结构,不是计算核心。 - **注意**:这是 `nn.Module` 的标准方法,在 `Linear` 中被覆盖以包含特定参数。没有额外示例,但打印模块时会自动调用。 ### 附加说明 - **继承方法**:`torch.nn.Linear` 继承自 `nn.Module`,因此支持所有 `nn.Module` 方法,如 `parameters()`(获取可训练参数)、`named_parameters()`(命名参数迭代)、`zero_grad()`(梯度清零)等。这些不是 `Linear` 特有,而是通用方法。如果需要讲解整个 `nn.Module`,可以进一步查询。 - **属性**:模块有 `weight`(权重张量)和 `bias`(偏置张量,如果启用),这些是 `nn.Parameter` 类型,可通过优化器更新。 - **使用提示**:在构建神经网络时,`Linear` 常用于 MLP 等。确保输入的最后一个维度匹配 `in_features`,否则会报错。文档强调该模块在 float16 输入时的特殊行为(在 ROCm 上转换为 float32)以避免精度问题。 如果需要代码示例、数学推导或特定版本的差异,请提供更多细节!