7.3 KiB
7.3 KiB
torch.nn.Sequential 类的概述
根据 PyTorch 官网(https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)的最新文档,torch.nn.Sequential
是一个顺序容器(Sequential container),它允许你将多个神经网络模块(nn.Module
)按顺序堆叠起来,形成一个简单的前向传播链。模块会按照构造函数中传入的顺序依次执行前向传播:输入数据传递给第一个模块,其输出作为下一个模块的输入,以此类推,最后返回最后一个模块的输出。
这个类继承自 nn.Module
,因此它继承了 nn.Module
的所有通用方法和属性(如 forward()
、parameters()
、train()
等)。但官网针对 Sequential
特有的方法和属性进行了详细描述。下面我将逐一讲解所有特有的方法和属性,包括名称、描述、参数、返回值以及示例(基于官网内容)。注意,Sequential
的核心是其 _modules
属性(一个有序字典,用于存储子模块),其他方法大多围绕它展开。
1. 构造函数 __init__
- 描述:初始化
Sequential
容器。你可以直接传入多个nn.Module
对象,或者传入一个OrderedDict
(键为字符串,值为模块)。这定义了模块的执行顺序。 - 参数:
*args
:可变数量的nn.Module
对象(例如,nn.Linear(10, 20), nn.ReLU()
)。arg
(可选):一个collections.OrderedDict
,其中键是模块的名称(字符串),值是nn.Module
对象。
- 返回值:无(返回
Sequential
实例本身)。 - 注意:
- 与
nn.ModuleList
的区别:ModuleList
只是一个模块列表,不自动连接层;Sequential
会自动将输出连接到下一个输入,形成级联。 - 支持动态添加模块,但初始化时定义的顺序固定。
- 与
- 示例:
当你打印
import torch.nn as nn from collections import OrderedDict # 直接传入模块 model = nn.Sequential( nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() ) # 使用 OrderedDict(便于命名模块) model = nn.Sequential( OrderedDict([ ("conv1", nn.Conv2d(1, 20, 5)), ("relu1", nn.ReLU()), ("conv2", nn.Conv2d(20, 64, 5)), ("relu2", nn.ReLU()), ]) )
model
时,它会显示所有子模块的结构。
2. 属性:_modules
- 描述:这是一个私有属性(但在文档中被提及),是一个
collections.OrderedDict
,存储所有子模块。键是模块的名称(字符串或整数索引),值是nn.Module
对象。这是Sequential
内部维护模块顺序的核心数据结构。 - 参数:无(它是只读的,但可以通过方法如
append
等间接修改)。 - 返回值:
OrderedDict
对象。 - 注意:
- 你不应该直接修改它(例如,不要用
model._modules['new'] = some_module
),而应使用提供的公共方法(如append
、insert
)来添加或修改模块,以确保正确性。 - 在前向传播中,
_forward_impl
方法会遍历这个字典,按顺序调用每个模块的forward
。
- 你不应该直接修改它(例如,不要用
- 示例:
这有助于调试,但实际使用中很少直接访问。
model = nn.Sequential(nn.Linear(1, 2)) print(model._modules) # 输出: OrderedDict([('0', Linear(...))])
3. 方法:append(module)
- 描述:将一个给定的模块追加到
Sequential
容器的末尾。这会动态扩展模型,而无需重新初始化。 - 参数:
module
:要追加的nn.Module
对象(必需)。
- 返回值:
Self
,即Sequential
实例本身(支持链式调用)。 - 注意:
- 追加后,模块会自动获得一个新的索引名称(如从 '0' 开始递增)。
- 这是一个便捷方法,适合在训练过程中动态添加层。
- 源代码位置:
torch/nn/modules/container.py
中的 Sequential 类。
- 示例:
import torch.nn as nn n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) n.append(nn.Linear(3, 4)) print(n) # 输出: # Sequential( # (0): Linear(in_features=1, out_features=2, bias=True) # (1): Linear(in_features=2, out_features=3, bias=True) # (2): Linear(in_features=3, out_features=4, bias=True) # )
4. 方法:extend(sequential)
- 描述:将另一个
Sequential
容器中的所有层扩展(追加)到当前容器的末尾。这相当于合并两个顺序模型。 - 参数:
sequential
:另一个Sequential
实例,其中的层将被追加(必需)。
- 返回值:
Self
,即Sequential
实例本身。 - 注意:
- 也可以使用
+
操作符实现类似效果,例如n + other
,它会返回一个新的Sequential
而非修改原对象。 - 如果
other
不是Sequential
,可能会引发错误;确保类型匹配。 - 源代码位置:同上。
- 也可以使用
- 示例:
import torch.nn as nn n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5)) n.extend(other) # 或者 n = n + other print(n) # 输出: # Sequential( # (0): Linear(in_features=1, out_features=2, bias=True) # (1): Linear(in_features=2, out_features=3, bias=True) # (2): Linear(in_features=3, out_features=4, bias=True) # (3): Linear(in_features=4, out_features=5, bias=True) # )
5. 方法:insert(index, module)
- 描述:在指定的索引位置插入一个模块到
Sequential
容器中。现有模块会根据插入位置向后移位。 - 参数:
index
:整数,指定插入位置(从 0 开始;如果超出范围,会追加到末尾)。module
:要插入的nn.Module
对象(必需)。
- 返回值:
Self
,即Sequential
实例本身。 - 注意:
- 如果
index
等于当前模块数量,它相当于append
。 - 插入后,模块名称会自动调整(例如,插入到索引 1,会将原 1 变为 2)。
- 源代码位置:同上。
- 如果
- 示例:
import torch.nn as nn n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) n.insert(1, nn.ReLU()) # 在索引 1 插入 ReLU print(n) # 输出: # Sequential( # (0): Linear(in_features=1, out_features=2, bias=True) # (1): ReLU() # (2): Linear(in_features=2, out_features=3, bias=True) # )
附加说明
- 继承的方法:除了以上特有方法,
Sequential
还继承了nn.Module
的所有方法,例如:forward(input)
:核心前向传播方法,自动遍历所有子模块。parameters()
/named_parameters()
:获取模型参数。zero_grad()
、step()
:用于优化器交互。to(device)
:移动到指定设备(如 GPU)。 这些在官网的nn.Module
文档中详细描述,这里不赘述。
- 使用建议:
Sequential
适合构建简单的线性网络(如 MLP 或 CNN 的基本堆叠)。对于复杂结构(如残差连接),推荐使用自定义nn.Module
子类。 - 版本信息:以上基于 PyTorch 2.4+ 文档(官网持续更新,当前日期 2025-09-09 时为最新)。如果需要代码验证,建议在实际环境中测试。
如果您需要某个方法的更详细代码示例、数学推导或其他 PyTorch 相关问题,请随时补充!