This commit is contained in:
e2hang
2025-09-09 15:56:55 +08:00
parent a5fdeaf70e
commit a8d78878fc
15 changed files with 2265 additions and 0 deletions

View File

@@ -0,0 +1,47 @@
### torch.nn.Linear 类的方法详解
根据 PyTorch 官方文档torch.nn.Linear这是一个实现线性变换全连接层的模块数学公式为 \( y = xA^T + b \),其中 \( A \) 是权重矩阵,\( b \) 是偏置(如果启用)。该类继承自 `nn.Module`,因此除了特定方法外,还支持 `nn.Module` 的通用方法(如 `train()``eval()` 等),但这里重点讲解文档中明确提到的所有方法,包括 `__init__``forward``extra_repr`。文档中未提及其他特定方法。
以下是每个方法的详细讲解,包括方法签名、参数、返回值和解释。参数类型基于 Python 标准类型和 PyTorch 张量规范。
#### 1. `__init__` 方法(初始化方法)
- **签名**`torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)`
- **参数**
- `in_features`int每个输入样本的大小输入特征维度
- `out_features`int每个输出样本的大小输出特征维度
- `bias`bool默认 True如果为 `False`,则不学习可加偏置项。
- `device`torch.device 或 None默认 None指定模块的设备如 CPU 或 GPU
- `dtype`torch.dtype 或 None默认 None指定模块的数据类型。
- **返回值**:无(这是构造函数,用于初始化模块)。
- **解释**:该方法初始化线性变换模块。它会创建权重矩阵 \( A \)(形状为 `(out_features, in_features)`)和可选的偏置 \( b \)(形状为 `(out_features,)`)。权重和偏置使用均匀分布 \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) 初始化,其中 \( k = \frac{1}{\text{in_features}} \)。这是一个核心方法,用于构建层,例如 `nn.Linear(20, 30)` 创建一个从 20 维输入到 30 维输出的层。
- **注意**:模块支持 TensorFloat32在支持的硬件上加速并且在某些 ROCm 设备上float16 输入会转换为 float32 进行计算。
#### 2. `forward` 方法(前向传播方法)
- **签名**`forward(input)`(输入参数名为 `input`,但文档中未显式列出签名,通常推断自描述)。
- **参数**
- `input`torch.Tensor输入张量形状为 `(*, H_in)`,其中 `*` 表示任意数量的维度(包括无维度),`H_in = in_features`(最后一个维度必须匹配 `in_features`)。
- **返回值**torch.Tensor形状为 `(*, H_out)`,其中前面的维度与输入相同,`H_out = out_features`
- **解释**:这是模块的核心前向传播方法。它对输入应用线性变换 \( y = xA^T + b \)(如果 `bias=True`)。权重 \( A \) 和偏置 \( b \) 是可学习的参数,通过反向传播更新。该方法是 PyTorch 模块的标准接口,在模型构建中调用如 `output = layer(input)` 时隐式执行。
- **示例**(来自官方文档):
```python
import torch.nn as nn
m = nn.Linear(20, 30) # 初始化层
input = torch.randn(128, 20) # 批量大小 128特征 20
output = m(input) # 前向传播
print(output.size()) # 输出: torch.Size([128, 30])
```
这展示了批量输入如何转换为输出形状。
#### 3. `extra_repr` 方法(额外字符串表示方法)
- **签名**`extra_repr()`(无参数,继承自 `nn.Module`)。
- **参数**:无。
- **返回值**str一个字符串总结模块的参数。
- **解释**:该方法返回一个字符串,用于增强模块的打印表示(例如,当你 `print(layer)` 时)。它通常包括 `in_features`、`out_features` 和偏置是否启用(如 "Linear(in_features=20, out_features=30, bias=True)")。这是一个辅助方法,帮助调试和可视化模型结构,不是计算核心。
- **注意**:这是 `nn.Module` 的标准方法,在 `Linear` 中被覆盖以包含特定参数。没有额外示例,但打印模块时会自动调用。
### 附加说明
- **继承方法**`torch.nn.Linear` 继承自 `nn.Module`,因此支持所有 `nn.Module` 方法,如 `parameters()`(获取可训练参数)、`named_parameters()`(命名参数迭代)、`zero_grad()`(梯度清零)等。这些不是 `Linear` 特有,而是通用方法。如果需要讲解整个 `nn.Module`,可以进一步查询。
- **属性**:模块有 `weight`(权重张量)和 `bias`(偏置张量,如果启用),这些是 `nn.Parameter` 类型,可通过优化器更新。
- **使用提示**:在构建神经网络时,`Linear` 常用于 MLP 等。确保输入的最后一个维度匹配 `in_features`,否则会报错。文档强调该模块在 float16 输入时的特殊行为(在 ROCm 上转换为 float32以避免精度问题。
如果需要代码示例、数学推导或特定版本的差异,请提供更多细节!