Files
python/Pytorch/nn/viFunc/nn.Linear.md
2025-09-09 15:56:55 +08:00

47 lines
4.7 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

### torch.nn.Linear 类的方法详解
根据 PyTorch 官方文档torch.nn.Linear这是一个实现线性变换全连接层的模块数学公式为 \( y = xA^T + b \),其中 \( A \) 是权重矩阵,\( b \) 是偏置(如果启用)。该类继承自 `nn.Module`,因此除了特定方法外,还支持 `nn.Module` 的通用方法(如 `train()``eval()` 等),但这里重点讲解文档中明确提到的所有方法,包括 `__init__``forward``extra_repr`。文档中未提及其他特定方法。
以下是每个方法的详细讲解,包括方法签名、参数、返回值和解释。参数类型基于 Python 标准类型和 PyTorch 张量规范。
#### 1. `__init__` 方法(初始化方法)
- **签名**`torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)`
- **参数**
- `in_features`int每个输入样本的大小输入特征维度
- `out_features`int每个输出样本的大小输出特征维度
- `bias`bool默认 True如果为 `False`,则不学习可加偏置项。
- `device`torch.device 或 None默认 None指定模块的设备如 CPU 或 GPU
- `dtype`torch.dtype 或 None默认 None指定模块的数据类型。
- **返回值**:无(这是构造函数,用于初始化模块)。
- **解释**:该方法初始化线性变换模块。它会创建权重矩阵 \( A \)(形状为 `(out_features, in_features)`)和可选的偏置 \( b \)(形状为 `(out_features,)`)。权重和偏置使用均匀分布 \(\mathcal{U}(-\sqrt{k}, \sqrt{k})\) 初始化,其中 \( k = \frac{1}{\text{in_features}} \)。这是一个核心方法,用于构建层,例如 `nn.Linear(20, 30)` 创建一个从 20 维输入到 30 维输出的层。
- **注意**:模块支持 TensorFloat32在支持的硬件上加速并且在某些 ROCm 设备上float16 输入会转换为 float32 进行计算。
#### 2. `forward` 方法(前向传播方法)
- **签名**`forward(input)`(输入参数名为 `input`,但文档中未显式列出签名,通常推断自描述)。
- **参数**
- `input`torch.Tensor输入张量形状为 `(*, H_in)`,其中 `*` 表示任意数量的维度(包括无维度),`H_in = in_features`(最后一个维度必须匹配 `in_features`)。
- **返回值**torch.Tensor形状为 `(*, H_out)`,其中前面的维度与输入相同,`H_out = out_features`
- **解释**:这是模块的核心前向传播方法。它对输入应用线性变换 \( y = xA^T + b \)(如果 `bias=True`)。权重 \( A \) 和偏置 \( b \) 是可学习的参数,通过反向传播更新。该方法是 PyTorch 模块的标准接口,在模型构建中调用如 `output = layer(input)` 时隐式执行。
- **示例**(来自官方文档):
```python
import torch.nn as nn
m = nn.Linear(20, 30) # 初始化层
input = torch.randn(128, 20) # 批量大小 128特征 20
output = m(input) # 前向传播
print(output.size()) # 输出: torch.Size([128, 30])
```
这展示了批量输入如何转换为输出形状。
#### 3. `extra_repr` 方法(额外字符串表示方法)
- **签名**`extra_repr()`(无参数,继承自 `nn.Module`)。
- **参数**:无。
- **返回值**str一个字符串总结模块的参数。
- **解释**:该方法返回一个字符串,用于增强模块的打印表示(例如,当你 `print(layer)` 时)。它通常包括 `in_features`、`out_features` 和偏置是否启用(如 "Linear(in_features=20, out_features=30, bias=True)")。这是一个辅助方法,帮助调试和可视化模型结构,不是计算核心。
- **注意**:这是 `nn.Module` 的标准方法,在 `Linear` 中被覆盖以包含特定参数。没有额外示例,但打印模块时会自动调用。
### 附加说明
- **继承方法**`torch.nn.Linear` 继承自 `nn.Module`,因此支持所有 `nn.Module` 方法,如 `parameters()`(获取可训练参数)、`named_parameters()`(命名参数迭代)、`zero_grad()`(梯度清零)等。这些不是 `Linear` 特有,而是通用方法。如果需要讲解整个 `nn.Module`,可以进一步查询。
- **属性**:模块有 `weight`(权重张量)和 `bias`(偏置张量,如果启用),这些是 `nn.Parameter` 类型,可通过优化器更新。
- **使用提示**:在构建神经网络时,`Linear` 常用于 MLP 等。确保输入的最后一个维度匹配 `in_features`,否则会报错。文档强调该模块在 float16 输入时的特殊行为(在 ROCm 上转换为 float32以避免精度问题。
如果需要代码示例、数学推导或特定版本的差异,请提供更多细节!