Files
python/Pytorch/nn/nn.Sequential.md
2025-09-09 15:10:57 +08:00

140 lines
7.3 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

### torch.nn.Sequential 类的概述
根据 PyTorch 官网https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html的最新文档`torch.nn.Sequential` 是一个顺序容器Sequential container它允许你将多个神经网络模块`nn.Module`)按顺序堆叠起来,形成一个简单的前向传播链。模块会按照构造函数中传入的顺序依次执行前向传播:输入数据传递给第一个模块,其输出作为下一个模块的输入,以此类推,最后返回最后一个模块的输出。
这个类继承自 `nn.Module`,因此它继承了 `nn.Module` 的所有通用方法和属性(如 `forward()``parameters()``train()` 等)。但官网针对 `Sequential` 特有的方法和属性进行了详细描述。下面我将逐一讲解所有特有的方法和属性,包括名称、描述、参数、返回值以及示例(基于官网内容)。注意,`Sequential` 的核心是其 `_modules` 属性(一个有序字典,用于存储子模块),其他方法大多围绕它展开。
#### 1. 构造函数 `__init__`
- **描述**:初始化 `Sequential` 容器。你可以直接传入多个 `nn.Module` 对象,或者传入一个 `OrderedDict`(键为字符串,值为模块)。这定义了模块的执行顺序。
- **参数**
- `*args`:可变数量的 `nn.Module` 对象(例如,`nn.Linear(10, 20), nn.ReLU()`)。
- `arg`(可选):一个 `collections.OrderedDict`,其中键是模块的名称(字符串),值是 `nn.Module` 对象。
- **返回值**:无(返回 `Sequential` 实例本身)。
- **注意**
-`nn.ModuleList` 的区别:`ModuleList` 只是一个模块列表,不自动连接层;`Sequential` 会自动将输出连接到下一个输入,形成级联。
- 支持动态添加模块,但初始化时定义的顺序固定。
- **示例**
```python
import torch.nn as nn
from collections import OrderedDict
# 直接传入模块
model = nn.Sequential(
nn.Conv2d(1, 20, 5),
nn.ReLU(),
nn.Conv2d(20, 64, 5),
nn.ReLU()
)
# 使用 OrderedDict便于命名模块
model = nn.Sequential(
OrderedDict([
("conv1", nn.Conv2d(1, 20, 5)),
("relu1", nn.ReLU()),
("conv2", nn.Conv2d(20, 64, 5)),
("relu2", nn.ReLU()),
])
)
```
当你打印 `model` 时,它会显示所有子模块的结构。
#### 2. 属性:`_modules`
- **描述**:这是一个私有属性(但在文档中被提及),是一个 `collections.OrderedDict`,存储所有子模块。键是模块的名称(字符串或整数索引),值是 `nn.Module` 对象。这是 `Sequential` 内部维护模块顺序的核心数据结构。
- **参数**:无(它是只读的,但可以通过方法如 `append` 等间接修改)。
- **返回值**`OrderedDict` 对象。
- **注意**
- 你不应该直接修改它(例如,不要用 `model._modules['new'] = some_module`),而应使用提供的公共方法(如 `append`、`insert`)来添加或修改模块,以确保正确性。
- 在前向传播中,`_forward_impl` 方法会遍历这个字典,按顺序调用每个模块的 `forward`。
- **示例**
```python
model = nn.Sequential(nn.Linear(1, 2))
print(model._modules) # 输出: OrderedDict([('0', Linear(...))])
```
这有助于调试,但实际使用中很少直接访问。
#### 3. 方法:`append(module)`
- **描述**:将一个给定的模块追加到 `Sequential` 容器的末尾。这会动态扩展模型,而无需重新初始化。
- **参数**
- `module`:要追加的 `nn.Module` 对象(必需)。
- **返回值**`Self`,即 `Sequential` 实例本身(支持链式调用)。
- **注意**
- 追加后,模块会自动获得一个新的索引名称(如从 '0' 开始递增)。
- 这是一个便捷方法,适合在训练过程中动态添加层。
- 源代码位置:`torch/nn/modules/container.py` 中的 Sequential 类。
- **示例**
```python
import torch.nn as nn
n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
n.append(nn.Linear(3, 4))
print(n)
# 输出:
# Sequential(
# (0): Linear(in_features=1, out_features=2, bias=True)
# (1): Linear(in_features=2, out_features=3, bias=True)
# (2): Linear(in_features=3, out_features=4, bias=True)
# )
```
#### 4. 方法:`extend(sequential)`
- **描述**:将另一个 `Sequential` 容器中的所有层扩展(追加)到当前容器的末尾。这相当于合并两个顺序模型。
- **参数**
- `sequential`:另一个 `Sequential` 实例,其中的层将被追加(必需)。
- **返回值**`Self`,即 `Sequential` 实例本身。
- **注意**
- 也可以使用 `+` 操作符实现类似效果,例如 `n + other`,它会返回一个新的 `Sequential` 而非修改原对象。
- 如果 `other` 不是 `Sequential`,可能会引发错误;确保类型匹配。
- 源代码位置:同上。
- **示例**
```python
import torch.nn as nn
n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))
n.extend(other) # 或者 n = n + other
print(n)
# 输出:
# Sequential(
# (0): Linear(in_features=1, out_features=2, bias=True)
# (1): Linear(in_features=2, out_features=3, bias=True)
# (2): Linear(in_features=3, out_features=4, bias=True)
# (3): Linear(in_features=4, out_features=5, bias=True)
# )
```
#### 5. 方法:`insert(index, module)`
- **描述**:在指定的索引位置插入一个模块到 `Sequential` 容器中。现有模块会根据插入位置向后移位。
- **参数**
- `index`:整数,指定插入位置(从 0 开始;如果超出范围,会追加到末尾)。
- `module`:要插入的 `nn.Module` 对象(必需)。
- **返回值**`Self`,即 `Sequential` 实例本身。
- **注意**
- 如果 `index` 等于当前模块数量,它相当于 `append`。
- 插入后,模块名称会自动调整(例如,插入到索引 1会将原 1 变为 2
- 源代码位置:同上。
- **示例**
```python
import torch.nn as nn
n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
n.insert(1, nn.ReLU()) # 在索引 1 插入 ReLU
print(n)
# 输出:
# Sequential(
# (0): Linear(in_features=1, out_features=2, bias=True)
# (1): ReLU()
# (2): Linear(in_features=2, out_features=3, bias=True)
# )
```
### 附加说明
- **继承的方法**:除了以上特有方法,`Sequential` 还继承了 `nn.Module` 的所有方法,例如:
- `forward(input)`:核心前向传播方法,自动遍历所有子模块。
- `parameters()` / `named_parameters()`:获取模型参数。
- `zero_grad()`、`step()`:用于优化器交互。
- `to(device)`:移动到指定设备(如 GPU
这些在官网的 `nn.Module` 文档中详细描述,这里不赘述。
- **使用建议**`Sequential` 适合构建简单的线性网络(如 MLP 或 CNN 的基本堆叠)。对于复杂结构(如残差连接),推荐使用自定义 `nn.Module` 子类。
- **版本信息**:以上基于 PyTorch 2.4+ 文档(官网持续更新,当前日期 2025-09-09 时为最新)。如果需要代码验证,建议在实际环境中测试。
如果您需要某个方法的更详细代码示例、数学推导或其他 PyTorch 相关问题,请随时补充!