Files
python/Pytorch/nn/Sequential/nn.Sequential.md
2025-09-09 15:56:55 +08:00

7.3 KiB
Raw Permalink Blame History

torch.nn.Sequential 类的概述

根据 PyTorch 官网(https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html的最新文档torch.nn.Sequential 是一个顺序容器Sequential container它允许你将多个神经网络模块nn.Module)按顺序堆叠起来,形成一个简单的前向传播链。模块会按照构造函数中传入的顺序依次执行前向传播:输入数据传递给第一个模块,其输出作为下一个模块的输入,以此类推,最后返回最后一个模块的输出。

这个类继承自 nn.Module,因此它继承了 nn.Module 的所有通用方法和属性(如 forward()parameters()train() 等)。但官网针对 Sequential 特有的方法和属性进行了详细描述。下面我将逐一讲解所有特有的方法和属性,包括名称、描述、参数、返回值以及示例(基于官网内容)。注意,Sequential 的核心是其 _modules 属性(一个有序字典,用于存储子模块),其他方法大多围绕它展开。

1. 构造函数 __init__

  • 描述:初始化 Sequential 容器。你可以直接传入多个 nn.Module 对象,或者传入一个 OrderedDict(键为字符串,值为模块)。这定义了模块的执行顺序。
  • 参数
    • *args:可变数量的 nn.Module 对象(例如,nn.Linear(10, 20), nn.ReLU())。
    • arg(可选):一个 collections.OrderedDict,其中键是模块的名称(字符串),值是 nn.Module 对象。
  • 返回值:无(返回 Sequential 实例本身)。
  • 注意
    • nn.ModuleList 的区别:ModuleList 只是一个模块列表,不自动连接层;Sequential 会自动将输出连接到下一个输入,形成级联。
    • 支持动态添加模块,但初始化时定义的顺序固定。
  • 示例
    import torch.nn as nn
    from collections import OrderedDict
    
    # 直接传入模块
    model = nn.Sequential(
        nn.Conv2d(1, 20, 5),
        nn.ReLU(),
        nn.Conv2d(20, 64, 5),
        nn.ReLU()
    )
    
    # 使用 OrderedDict便于命名模块
    model = nn.Sequential(
        OrderedDict([
            ("conv1", nn.Conv2d(1, 20, 5)),
            ("relu1", nn.ReLU()),
            ("conv2", nn.Conv2d(20, 64, 5)),
            ("relu2", nn.ReLU()),
        ])
    )
    
    当你打印 model 时,它会显示所有子模块的结构。

2. 属性:_modules

  • 描述:这是一个私有属性(但在文档中被提及),是一个 collections.OrderedDict,存储所有子模块。键是模块的名称(字符串或整数索引),值是 nn.Module 对象。这是 Sequential 内部维护模块顺序的核心数据结构。
  • 参数:无(它是只读的,但可以通过方法如 append 等间接修改)。
  • 返回值OrderedDict 对象。
  • 注意
    • 你不应该直接修改它(例如,不要用 model._modules['new'] = some_module),而应使用提供的公共方法(如 appendinsert)来添加或修改模块,以确保正确性。
    • 在前向传播中,_forward_impl 方法会遍历这个字典,按顺序调用每个模块的 forward
  • 示例
    model = nn.Sequential(nn.Linear(1, 2))
    print(model._modules)  # 输出: OrderedDict([('0', Linear(...))])
    
    这有助于调试,但实际使用中很少直接访问。

3. 方法:append(module)

  • 描述:将一个给定的模块追加到 Sequential 容器的末尾。这会动态扩展模型,而无需重新初始化。
  • 参数
    • module:要追加的 nn.Module 对象(必需)。
  • 返回值Self,即 Sequential 实例本身(支持链式调用)。
  • 注意
    • 追加后,模块会自动获得一个新的索引名称(如从 '0' 开始递增)。
    • 这是一个便捷方法,适合在训练过程中动态添加层。
    • 源代码位置:torch/nn/modules/container.py 中的 Sequential 类。
  • 示例
    import torch.nn as nn
    
    n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
    n.append(nn.Linear(3, 4))
    print(n)
    # 输出:
    # Sequential(
    #   (0): Linear(in_features=1, out_features=2, bias=True)
    #   (1): Linear(in_features=2, out_features=3, bias=True)
    #   (2): Linear(in_features=3, out_features=4, bias=True)
    # )
    

4. 方法:extend(sequential)

  • 描述:将另一个 Sequential 容器中的所有层扩展(追加)到当前容器的末尾。这相当于合并两个顺序模型。
  • 参数
    • sequential:另一个 Sequential 实例,其中的层将被追加(必需)。
  • 返回值Self,即 Sequential 实例本身。
  • 注意
    • 也可以使用 + 操作符实现类似效果,例如 n + other,它会返回一个新的 Sequential 而非修改原对象。
    • 如果 other 不是 Sequential,可能会引发错误;确保类型匹配。
    • 源代码位置:同上。
  • 示例
    import torch.nn as nn
    
    n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
    other = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 5))
    n.extend(other)  # 或者 n = n + other
    print(n)
    # 输出:
    # Sequential(
    #   (0): Linear(in_features=1, out_features=2, bias=True)
    #   (1): Linear(in_features=2, out_features=3, bias=True)
    #   (2): Linear(in_features=3, out_features=4, bias=True)
    #   (3): Linear(in_features=4, out_features=5, bias=True)
    # )
    

5. 方法:insert(index, module)

  • 描述:在指定的索引位置插入一个模块到 Sequential 容器中。现有模块会根据插入位置向后移位。
  • 参数
    • index:整数,指定插入位置(从 0 开始;如果超出范围,会追加到末尾)。
    • module:要插入的 nn.Module 对象(必需)。
  • 返回值Self,即 Sequential 实例本身。
  • 注意
    • 如果 index 等于当前模块数量,它相当于 append
    • 插入后,模块名称会自动调整(例如,插入到索引 1会将原 1 变为 2
    • 源代码位置:同上。
  • 示例
    import torch.nn as nn
    
    n = nn.Sequential(nn.Linear(1, 2), nn.Linear(2, 3))
    n.insert(1, nn.ReLU())  # 在索引 1 插入 ReLU
    print(n)
    # 输出:
    # Sequential(
    #   (0): Linear(in_features=1, out_features=2, bias=True)
    #   (1): ReLU()
    #   (2): Linear(in_features=2, out_features=3, bias=True)
    # )
    

附加说明

  • 继承的方法:除了以上特有方法,Sequential 还继承了 nn.Module 的所有方法,例如:
    • forward(input):核心前向传播方法,自动遍历所有子模块。
    • parameters() / named_parameters():获取模型参数。
    • zero_grad()step():用于优化器交互。
    • to(device):移动到指定设备(如 GPU。 这些在官网的 nn.Module 文档中详细描述,这里不赘述。
  • 使用建议Sequential 适合构建简单的线性网络(如 MLP 或 CNN 的基本堆叠)。对于复杂结构(如残差连接),推荐使用自定义 nn.Module 子类。
  • 版本信息:以上基于 PyTorch 2.4+ 文档(官网持续更新,当前日期 2025-09-09 时为最新)。如果需要代码验证,建议在实际环境中测试。

如果您需要某个方法的更详细代码示例、数学推导或其他 PyTorch 相关问题,请随时补充!