### torch.nn.Sequential 类的概述 根据 PyTorch 官网(https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)的最新文档,`torch.nn.Sequential` 是一个顺序容器(Sequential container),它允许你将多个神经网络模块(`nn.Module`)按顺序堆叠起来,形成一个简单的前向传播链。模块会按照构造函数中传入的顺序依次执行前向传播:输入数据传递给第一个模块,其输出作为下一个模块的输入,以此类推,最后返回最后一个模块的输出。 这个类继承自 `nn.Module`,因此它继承了 `nn.Module` 的所有通用方法和属性(如 `forward()`、`parameters()`、`train()` 等)。但官网针对 `Sequential` 特有的方法和属性进行了详细描述。下面我将逐一讲解所有特有的方法和属性,包括名称、描述、参数、返回值以及示例(基于官网内容)。注意,`Sequential` 的核心是其 `_modules` 属性(一个有序字典,用于存储子模块),其他方法大多围绕它展开。 #### 1. 构造函数 `__init__` - **描述**:初始化 `Sequential` 容器。你可以直接传入多个 `nn.Module` 对象,或者传入一个 `OrderedDict`(键为字符串,值为模块)。这定义了模块的执行顺序。 - **参数**: - `*args`:可变数量的 `nn.Module` 对象(例如,`nn.Linear(10, 20), nn.ReLU()`)。 - `arg`(可选):一个 `collections.OrderedDict`,其中键是模块的名称(字符串),值是 `nn.Module` 对象。 - **返回值**:无(返回 `Sequential` 实例本身)。 - **注意**: - 与 `nn.ModuleList` 的区别:`ModuleList` 只是一个模块列表,不自动连接层;`Sequential` 会自动将输出连接到下一个输入,形成级联。 - 支持动态添加模块,但初始化时定义的顺序固定。 - **示例**: ```python import torch.nn as nn from collections import OrderedDict # 直接传入模块 model = nn.Sequential( nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() ) # 使用 OrderedDict(便于命名模块) model = nn.Sequential( OrderedDict([ ("conv1", nn.Conv2d(1, 20, 5)), ("relu1", nn.ReLU()), ("conv2", nn.Conv2d(20, 64, 5)), ("relu2", nn.ReLU()), ]) ) ``` 当你打印 `model` 时,它会显示所有子模块的结构。 #### 2. 属性:`_modules` - **描述**:这是一个私有属性(但在文档中被提及),是一个 `collections.OrderedDict`,存储所有子模块。键是模块的名称(字符串或整数索引),值是 `nn.Module` 对象。这是 `Sequential` 内部维护模块顺序的核心数据结构。 - **参数**:无(它是只读的,但可以通过方法如 `append` 等间接修改)。 - **返回值**:`OrderedDict` 对象。 - **注意**: - 你不应该直接修改它(例如,不要用 `model._modules['new'] = some_module`),而应使用提供的公共方法(如 `append`、`insert`)来添加或修改模块,以确保正确性。 - 在前向传播中,`_forward_impl` 方法会遍历这个字典,按顺序调用每个模块的 `forward`。 - **示例**: ```python model = nn.Sequential(nn.Linear(1, 2)) print(model._modules) # 输出: OrderedDict([('0', Linear(...))]) ``` 这有助于调试,但实际使用中很少直接访问。 #### 3. 方法:`append(module)` - **描述**:将一个给定的模块追加到 `Sequential` 容器的末尾。这会动态扩展模型,而无需重新初始化。 - **参数**: - `module`:要追加的 `nn.Module` 对象(必需)。 - **返回值**:`Self`,即 `Sequential` 实例本身(支持链式调用)。 - **注意**: - 追加后,模块会自动获得一个新的索引名称(如从 '0' 开始递增)。 - 这是一个便捷方法,适合在训练过程中动态添加层。 - 源代码位置:`torch/nn/modules/container.py` 中的 Sequential 类。 - **示例**: ```python import torch.nn as nn n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) n.append(nn.Linear(3, 4)) print(n) # 输出: # Sequential( # (0): Linear(in_features=1, out_features=2, bias=True) # (1): Linear(in_features=2, out_features=3, bias=True) # (2): Linear(in_features=3, out_features=4, bias=True) # ) ``` #### 4. 方法:`extend(sequential)` - **描述**:将另一个 `Sequential` 容器中的所有层扩展(追加)到当前容器的末尾。这相当于合并两个顺序模型。 - **参数**: - `sequential`:另一个 `Sequential` 实例,其中的层将被追加(必需)。 - **返回值**:`Self`,即 `Sequential` 实例本身。 - **注意**: - 也可以使用 `+` 操作符实现类似效果,例如 `n + other`,它会返回一个新的 `Sequential` 而非修改原对象。 - 如果 `other` 不是 `Sequential`,可能会引发错误;确保类型匹配。 - 源代码位置:同上。 - **示例**: ```python import torch.nn as nn n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5)) n.extend(other) # 或者 n = n + other print(n) # 输出: # Sequential( # (0): Linear(in_features=1, out_features=2, bias=True) # (1): Linear(in_features=2, out_features=3, bias=True) # (2): Linear(in_features=3, out_features=4, bias=True) # (3): Linear(in_features=4, out_features=5, bias=True) # ) ``` #### 5. 方法:`insert(index, module)` - **描述**:在指定的索引位置插入一个模块到 `Sequential` 容器中。现有模块会根据插入位置向后移位。 - **参数**: - `index`:整数,指定插入位置(从 0 开始;如果超出范围,会追加到末尾)。 - `module`:要插入的 `nn.Module` 对象(必需)。 - **返回值**:`Self`,即 `Sequential` 实例本身。 - **注意**: - 如果 `index` 等于当前模块数量,它相当于 `append`。 - 插入后,模块名称会自动调整(例如,插入到索引 1,会将原 1 变为 2)。 - 源代码位置:同上。 - **示例**: ```python import torch.nn as nn n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3)) n.insert(1, nn.ReLU()) # 在索引 1 插入 ReLU print(n) # 输出: # Sequential( # (0): Linear(in_features=1, out_features=2, bias=True) # (1): ReLU() # (2): Linear(in_features=2, out_features=3, bias=True) # ) ``` ### 附加说明 - **继承的方法**:除了以上特有方法,`Sequential` 还继承了 `nn.Module` 的所有方法,例如: - `forward(input)`:核心前向传播方法,自动遍历所有子模块。 - `parameters()` / `named_parameters()`:获取模型参数。 - `zero_grad()`、`step()`:用于优化器交互。 - `to(device)`:移动到指定设备(如 GPU)。 这些在官网的 `nn.Module` 文档中详细描述,这里不赘述。 - **使用建议**:`Sequential` 适合构建简单的线性网络(如 MLP 或 CNN 的基本堆叠)。对于复杂结构(如残差连接),推荐使用自定义 `nn.Module` 子类。 - **版本信息**:以上基于 PyTorch 2.4+ 文档(官网持续更新,当前日期 2025-09-09 时为最新)。如果需要代码验证,建议在实际环境中测试。 如果您需要某个方法的更详细代码示例、数学推导或其他 PyTorch 相关问题,请随时补充!