Files
python/Pytorch/nn/nn.Module.md
2025-09-09 15:10:57 +08:00

15 KiB
Raw Blame History

torch.nn.Module 是 PyTorch 中构建神经网络的核心基类,所有的神经网络模型都继承自它。它提供了许多方法来管理模块、参数、子模块和前向传播等功能。以下基于 PyTorch 官方文档(https://pytorch.org/docs/stable/generated/torch.nn.Module.html torch.nn.Module 的所有方法进行详细讲解,力求清晰、简洁且实用。


1. __init__ 方法

  • 描述: 初始化方法,用于定义模块的结构(如层、子模块等)。
  • 用法: 在自定义模块时,重写 __init__ 来定义子模块或参数。
  • 示例:
    import torch.nn as nn
    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            self.fc1 = nn.Linear(10, 5)
    
  • 说明: 调用 super().__init__() 确保基类的初始化,子模块需注册到模块中(通过赋值给 self)。

2. forward(*args, **kwargs)

  • 描述: 定义模块的前向传播逻辑,子类必须实现此方法。
  • 用法: 输入数据通过此方法计算输出。
  • 示例:
    def forward(self, x):
        x = self.fc1(x)
        return x
    
  • 说明: 这是核心方法,所有输入数据的计算流程在此定义。调用模型实例(如 model(x))时会自动调用 forward

3. __call__ 方法

  • 描述: 使模块实例可调用,内部调用 forward 方法并添加钩子hook功能。
  • 用法: 不需要显式重写,用户通过 model(input) 间接调用。
  • 说明: 通常不需要直接操作,但了解它是 model(input) 的实现基础。

4. add_module(name: str, module: Module) -> None

  • 描述: 向模块添加一个子模块,并以指定名称注册。
  • 用法: 动态添加子模块。
  • 示例:
    model = nn.Module()
    model.add_module("fc1", nn.Linear(10, 5))
    
  • 说明: 等价于 self.name = module,但更灵活,适合动态构建模型。

5. apply(fn: Callable[['Module'], None]) -> T

  • 描述: 递归地将函数 fn 应用到模块及其所有子模块。
  • 用法: 用于初始化参数或修改模块属性。
  • 示例:
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
    model.apply(init_weights)
    
  • 说明: 常用于自定义权重初始化。

6. buffers(recurse: bool = True) -> Iterator[Tensor]

  • 描述: 返回模块中所有缓冲区(如 running_meanrunning_var)的迭代器。
  • 用法: 获取模块的非可训练参数(如 BatchNorm 的统计数据)。
  • 示例:
    for buf in model.buffers():
        print(buf)
    
  • 说明: 如果 recurse=True,包括子模块的缓冲区。

7. children() -> Iterator[Module]

  • 描述: 返回直接子模块的迭代器。
  • 用法: 遍历模型的直接子模块。
  • 示例:
    for child in model.children():
        print(child)
    
  • 说明: 仅返回直接子模块,不递归到更深层。

8. cpu() -> T

  • 描述: 将模块的所有参数和缓冲区移动到 CPU。
  • 用法: 用于设备切换。
  • 示例:
    model.cpu()
    
  • 说明: 确保模型和数据在同一设备上运行。

9. cuda(device: Optional[Union[int, torch.device]] = None) -> T

  • 描述: 将模块的所有参数和缓冲区移动到指定的 GPU 设备。
  • 用法: 指定 GPU 设备(如 cuda:0)。
  • 示例:
    model.cuda(0)  # 移动到 GPU 0
    
  • 说明: 如果不指定 device,使用当前默认 GPU。

10. double() -> T

  • 描述: 将模块的参数和缓冲区转换为 torch.float64(双精度浮点)。
  • 用法: 提高数值精度。
  • 示例:
    model.double()
    
  • 说明: 通常在需要高精度计算时使用,但会增加内存占用。

11. eval() -> T

  • 描述: 将模块设置为评估模式。
  • 用法: 关闭训练特有的行为(如 Dropout 和 BatchNorm 的更新)。
  • 示例:
    model.eval()
    
  • 说明: 与 train(False) 等效,用于推理阶段。

12. float() -> T

  • 描述: 将模块的参数和缓冲区转换为 torch.float32(单精度浮点)。
  • 用法: 默认浮点类型,适合大多数深度学习任务。
  • 示例:
    model.float()
    
  • 说明: 比 double() 更节省内存。

13. get_buffer(target: str) -> Tensor

  • 描述: 返回指定名称的缓冲区。
  • 用法: 访问特定缓冲区(如 BatchNorm 的 running_mean)。
  • 示例:
    running_mean = model.get_buffer("bn.running_mean")
    
  • 说明: 如果缓冲区不存在,会抛出 KeyError

14. get_parameter(target: str) -> Parameter

  • 描述: 返回指定名称的参数。
  • 用法: 访问特定参数(如 weightbias)。
  • 示例:
    weight = model.get_parameter("fc1.weight")
    
  • 说明: 参数必须是注册的参数,否则抛出 KeyError

15. get_submodule(target: str) -> Module

  • 描述: 返回指定路径的子模块。
  • 用法: 访问嵌套子模块。
  • 示例:
    submodule = model.get_submodule("block1.conv1")
    
  • 说明: 使用点号(如 block1.conv1)访问嵌套结构。

16. half() -> T

  • 描述: 将模块的参数和缓冲区转换为 torch.float16(半精度浮点)。
  • 用法: 用于加速计算和节省 GPU 内存。
  • 示例:
    model.half()
    
  • 说明: 需确保硬件支持半精度运算。

17. load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> _IncompatibleKeys

  • 描述: 从 state_dict 加载参数和缓冲区。
  • 用法: 用于加载预训练模型或检查点。
  • 示例:
    model.load_state_dict(torch.load("model.pth"))
    
  • 说明:
    • strict=True:要求 state_dict 的键完全匹配。
    • assign=True:直接赋值而非复制(实验性功能)。
    • 返回 _IncompatibleKeys,指示缺失或多余的键。

18. modules() -> Iterator[Module]

  • 描述: 返回模块及其所有子模块(递归)的迭代器。
  • 用法: 遍历整个模块树。
  • 示例:
    for module in model.modules():
        print(module)
    
  • 说明: 比 children() 更深入,包含所有层级子模块。

19. named_buffers(recurse: bool = True, prefix: str = '', remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]

  • 描述: 返回模块中所有缓冲区的名称和值的迭代器。
  • 用法: 获取缓冲区的名称和值。
  • 示例:
    for name, buf in model.named_buffers():
        print(name, buf)
    
  • 说明: 如果 recurse=True,包括子模块的缓冲区。

20. named_children() -> Iterator[Tuple[str, Module]]

  • 描述: 返回直接子模块的名称和模块的迭代器。
  • 用法: 遍历直接子模块及其名称。
  • 示例:
    for name, child in model.named_children():
        print(name, child)
    
  • 说明: 仅返回直接子模块,不递归。

21. named_modules(memo: Optional[Set[Module]] = None, prefix: str = '', remove_duplicate: bool = True) -> Iterator[Tuple[str, Module]]

  • 描述: 返回模块及其所有子模块(递归)的名称和模块的迭代器。
  • 用法: 遍历整个模块树及其名称。
  • 示例:
    for name, module in model.named_modules():
        print(name, module)
    
  • 说明: 比 named_children() 更深入,包含所有层级。

22. named_parameters(recurse: bool = True, prefix: str = '', remove_duplicate: bool = True) -> Iterator[Tuple[str, Parameter]]

  • 描述: 返回模块中所有参数的名称和值的迭代器。
  • 用法: 访问模型的参数(如权重和偏置)。
  • 示例:
    for name, param in model.named_parameters():
        print(name, param.shape)
    
  • 说明: 如果 recurse=True,包括子模块的参数。

23. parameters(recurse: bool = True) -> Iterator[Parameter]

  • 描述: 返回模块中所有参数的迭代器。
  • 用法: 用于优化器配置。
  • 示例:
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    
  • 说明: 如果 recurse=True,包括子模块的参数。

24. register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) -> None

  • 描述: 注册一个缓冲区(如非可训练的张量)。
  • 用法: 用于存储不需要梯度的张量(如 BatchNorm 的 running_mean)。
  • 示例:
    self.register_buffer("running_mean", torch.zeros(10))
    
  • 说明: persistent=True 表示缓冲区会保存到 state_dict

25. register_forward_hook(hook: Callable[..., None], *, prepend: bool = False, with_kwargs: bool = False) -> RemovableHandle

  • 描述: 注册一个前向传播钩子函数,在模块前向传播时调用。
  • 用法: 用于监控或修改前向传播的输入/输出。
  • 示例:
    def hook(module, input, output):
        print(output)
    handle = model.register_forward_hook(hook)
    
  • 说明: 返回 RemovableHandle,可通过 handle.remove() 移除钩子。

26. register_forward_pre_hook(hook: Callable[..., None], *, prepend: bool = False, with_kwargs: bool = False) -> RemovableHandle

  • 描述: 注册一个前向传播前的钩子函数,在 forward 调用前触发。
  • 用法: 用于修改输入或调试。
  • 示例:
    def pre_hook(module, input):
        print(input)
    handle = model.register_forward_pre_hook(pre_hook)
    
  • 说明: 类似 register_forward_hook,但在 forward 之前运行。

27. register_full_backward_hook(hook: Callable[..., None], prepend: bool = False) -> RemovableHandle

  • 描述: 注册一个反向传播钩子函数,在梯度计算时调用。
  • 用法: 用于监控或修改梯度。
  • 示例:
    def backward_hook(module, grad_input, grad_output):
        print(grad_output)
    handle = model.register_full_backward_hook(backward_hook)
    
  • 说明: 在 PyTorch 2.0+ 中推荐使用,替代旧的 register_backward_hook

28. register_parameter(name: str, param: Optional[Parameter]) -> None

  • 描述: 注册一个参数到模块。
  • 用法: 动态添加可训练参数。
  • 示例:
    self.register_parameter("weight", nn.Parameter(torch.randn(10, 5)))
    
  • 说明: 参数会自动加入 parameters()state_dict

29. requires_grad_(requires_grad: bool = True) -> T

  • 描述: 设置模块所有参数的 requires_grad 属性。
  • 用法: 冻结或解冻参数。
  • 示例:
    model.requires_grad_(False)  # 冻结参数
    
  • 说明: 常用于冻结预训练模型的部分层。

30. share_memory() -> T

  • 描述: 将模块的参数和缓冲区移动到共享内存。
  • 用法: 用于多进程数据共享。
  • 示例:
    model.share_memory()
    
  • 说明: 主要用于 torch.multiprocessing 场景。

31. state_dict(*args, destination: Optional[Dict[str, Tensor]] = None, prefix: str = '', keep_vars: bool = False) -> Dict[str, Tensor]

  • 描述: 返回模块的参数和缓冲区的状态字典。
  • 用法: 保存模型状态。
  • 示例:
    torch.save(model.state_dict(), "model.pth")
    
  • 说明: keep_vars=True 时保留 TensorVariable 特性(较少使用)。

32. to(*args, **kwargs) -> T

  • 描述: 将模块的参数和缓冲区移动到指定设备或数据类型。
  • 用法: 灵活的设备/类型转换。
  • 示例:
    model.to(device="cuda", dtype=torch.float16)
    
  • 说明: 支持多种参数形式(如 to(device), to(dtype))。

33. to_empty(*, device: Union[str, torch.device], recurse: bool = True) -> T

  • 描述: 将模块的参数和缓冲区移动到指定设备,但不初始化内容。
  • 用法: 用于初始化空模型。
  • 示例:
    model.to_empty(device="cuda")
    
  • 说明: 较少使用,适合特殊场景。

34. train(mode: bool = True) -> T

  • 描述: 设置模块的训练模式。
  • 用法: 启用/禁用训练特有的行为(如 Dropout、BatchNorm
  • 示例:
    model.train()  # 训练模式
    model.train(False)  # 等同于 eval()
    
  • 说明: 与 eval() 相对,用于切换训练/评估模式。

35. type(dst_type: Union[str, torch.dtype]) -> T

  • 描述: 将模块的参数和缓冲区转换为指定数据类型。
  • 用法: 类似 float(), double() 等,但更通用。
  • 示例:
    model.type(torch.float16)
    
  • 说明: 支持字符串(如 "torch.float32")或 torch.dtype

36. xpu(device: Optional[Union[int, torch.device]] = None) -> T

  • 描述: 将模块移动到 XPU 设备Intel GPU
  • 用法: 用于支持 XPU 的设备。
  • 示例:
    model.xpu(0)
    
  • 说明: 仅在支持 XPU 的环境中有效。

37. zero_grad(set_to_none: bool = False) -> None

  • 描述: 将模块所有参数的梯度清零。
  • 用法: 在优化步骤前调用。
  • 示例:
    model.zero_grad()
    
  • 说明: 如果 set_to_none=True,梯度设为 None 而非 0更节省内存

注意事项

  • 模块管理: 使用 children(), modules(), named_parameters() 等方法可以方便地管理复杂模型的结构和参数。
  • 设备与类型: 确保模型和输入数据在同一设备和数据类型(to(), cuda(), float() 等)。
  • 钩子函数: 钩子(如 register_forward_hook)是调试和动态修改模型的强大工具。
  • 状态保存: state_dict()load_state_dict() 是保存和加载模型的关键方法。

实践建议

你可以尝试以下代码来探索 torch.nn.Module 的功能:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.register_buffer("running_mean", torch.zeros(5))

    def forward(self, x):
        return self.fc1(x)

model = MyModel()
print(list(model.named_parameters()))  # 查看参数
print(list(model.named_buffers()))    # 查看缓冲区
model.to("cuda")                      # 移动到 GPU
model.eval()                          # 设置评估模式
torch.save(model.state_dict(), "model.pth")  # 保存模型

如果你想深入某个方法(如钩子函数或 state_dict)或需要具体代码示例,请告诉我,我可以进一步展开!