`torch.nn.Module` 是 PyTorch 中构建神经网络的核心基类,所有的神经网络模型都继承自它。它提供了许多方法来管理模块、参数、子模块和前向传播等功能。以下基于 PyTorch 官方文档(https://pytorch.org/docs/stable/generated/torch.nn.Module.html)对 `torch.nn.Module` 的所有方法进行详细讲解,力求清晰、简洁且实用。 --- ### 1. `__init__` 方法 - **描述**: 初始化方法,用于定义模块的结构(如层、子模块等)。 - **用法**: 在自定义模块时,重写 `__init__` 来定义子模块或参数。 - **示例**: ```python import torch.nn as nn class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.fc1 = nn.Linear(10, 5) ``` - **说明**: 调用 `super().__init__()` 确保基类的初始化,子模块需注册到模块中(通过赋值给 `self`)。 --- ### 2. `forward(*args, **kwargs)` - **描述**: 定义模块的前向传播逻辑,子类必须实现此方法。 - **用法**: 输入数据通过此方法计算输出。 - **示例**: ```python def forward(self, x): x = self.fc1(x) return x ``` - **说明**: 这是核心方法,所有输入数据的计算流程在此定义。调用模型实例(如 `model(x)`)时会自动调用 `forward`。 --- ### 3. `__call__` 方法 - **描述**: 使模块实例可调用,内部调用 `forward` 方法并添加钩子(hook)功能。 - **用法**: 不需要显式重写,用户通过 `model(input)` 间接调用。 - **说明**: 通常不需要直接操作,但了解它是 `model(input)` 的实现基础。 --- ### 4. `add_module(name: str, module: Module) -> None` - **描述**: 向模块添加一个子模块,并以指定名称注册。 - **用法**: 动态添加子模块。 - **示例**: ```python model = nn.Module() model.add_module("fc1", nn.Linear(10, 5)) ``` - **说明**: 等价于 `self.name = module`,但更灵活,适合动态构建模型。 --- ### 5. `apply(fn: Callable[['Module'], None]) -> T` - **描述**: 递归地将函数 `fn` 应用到模块及其所有子模块。 - **用法**: 用于初始化参数或修改模块属性。 - **示例**: ```python def init_weights(m): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) model.apply(init_weights) ``` - **说明**: 常用于自定义权重初始化。 --- ### 6. `buffers(recurse: bool = True) -> Iterator[Tensor]` - **描述**: 返回模块中所有缓冲区(如 `running_mean` 或 `running_var`)的迭代器。 - **用法**: 获取模块的非可训练参数(如 BatchNorm 的统计数据)。 - **示例**: ```python for buf in model.buffers(): print(buf) ``` - **说明**: 如果 `recurse=True`,包括子模块的缓冲区。 --- ### 7. `children() -> Iterator[Module]` - **描述**: 返回直接子模块的迭代器。 - **用法**: 遍历模型的直接子模块。 - **示例**: ```python for child in model.children(): print(child) ``` - **说明**: 仅返回直接子模块,不递归到更深层。 --- ### 8. `cpu() -> T` - **描述**: 将模块的所有参数和缓冲区移动到 CPU。 - **用法**: 用于设备切换。 - **示例**: ```python model.cpu() ``` - **说明**: 确保模型和数据在同一设备上运行。 --- ### 9. `cuda(device: Optional[Union[int, torch.device]] = None) -> T` - **描述**: 将模块的所有参数和缓冲区移动到指定的 GPU 设备。 - **用法**: 指定 GPU 设备(如 `cuda:0`)。 - **示例**: ```python model.cuda(0) # 移动到 GPU 0 ``` - **说明**: 如果不指定 `device`,使用当前默认 GPU。 --- ### 10. `double() -> T` - **描述**: 将模块的参数和缓冲区转换为 `torch.float64`(双精度浮点)。 - **用法**: 提高数值精度。 - **示例**: ```python model.double() ``` - **说明**: 通常在需要高精度计算时使用,但会增加内存占用。 --- ### 11. `eval() -> T` - **描述**: 将模块设置为评估模式。 - **用法**: 关闭训练特有的行为(如 Dropout 和 BatchNorm 的更新)。 - **示例**: ```python model.eval() ``` - **说明**: 与 `train(False)` 等效,用于推理阶段。 --- ### 12. `float() -> T` - **描述**: 将模块的参数和缓冲区转换为 `torch.float32`(单精度浮点)。 - **用法**: 默认浮点类型,适合大多数深度学习任务。 - **示例**: ```python model.float() ``` - **说明**: 比 `double()` 更节省内存。 --- ### 13. `get_buffer(target: str) -> Tensor` - **描述**: 返回指定名称的缓冲区。 - **用法**: 访问特定缓冲区(如 BatchNorm 的 `running_mean`)。 - **示例**: ```python running_mean = model.get_buffer("bn.running_mean") ``` - **说明**: 如果缓冲区不存在,会抛出 `KeyError`。 --- ### 14. `get_parameter(target: str) -> Parameter` - **描述**: 返回指定名称的参数。 - **用法**: 访问特定参数(如 `weight` 或 `bias`)。 - **示例**: ```python weight = model.get_parameter("fc1.weight") ``` - **说明**: 参数必须是注册的参数,否则抛出 `KeyError`。 --- ### 15. `get_submodule(target: str) -> Module` - **描述**: 返回指定路径的子模块。 - **用法**: 访问嵌套子模块。 - **示例**: ```python submodule = model.get_submodule("block1.conv1") ``` - **说明**: 使用点号(如 `block1.conv1`)访问嵌套结构。 --- ### 16. `half() -> T` - **描述**: 将模块的参数和缓冲区转换为 `torch.float16`(半精度浮点)。 - **用法**: 用于加速计算和节省 GPU 内存。 - **示例**: ```python model.half() ``` - **说明**: 需确保硬件支持半精度运算。 --- ### 17. `load_state_dict(state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False) -> _IncompatibleKeys` - **描述**: 从 `state_dict` 加载参数和缓冲区。 - **用法**: 用于加载预训练模型或检查点。 - **示例**: ```python model.load_state_dict(torch.load("model.pth")) ``` - **说明**: - `strict=True`:要求 `state_dict` 的键完全匹配。 - `assign=True`:直接赋值而非复制(实验性功能)。 - 返回 `_IncompatibleKeys`,指示缺失或多余的键。 --- ### 18. `modules() -> Iterator[Module]` - **描述**: 返回模块及其所有子模块(递归)的迭代器。 - **用法**: 遍历整个模块树。 - **示例**: ```python for module in model.modules(): print(module) ``` - **说明**: 比 `children()` 更深入,包含所有层级子模块。 --- ### 19. `named_buffers(recurse: bool = True, prefix: str = '', remove_duplicate: bool = True) -> Iterator[Tuple[str, Tensor]]` - **描述**: 返回模块中所有缓冲区的名称和值的迭代器。 - **用法**: 获取缓冲区的名称和值。 - **示例**: ```python for name, buf in model.named_buffers(): print(name, buf) ``` - **说明**: 如果 `recurse=True`,包括子模块的缓冲区。 --- ### 20. `named_children() -> Iterator[Tuple[str, Module]]` - **描述**: 返回直接子模块的名称和模块的迭代器。 - **用法**: 遍历直接子模块及其名称。 - **示例**: ```python for name, child in model.named_children(): print(name, child) ``` - **说明**: 仅返回直接子模块,不递归。 --- ### 21. `named_modules(memo: Optional[Set[Module]] = None, prefix: str = '', remove_duplicate: bool = True) -> Iterator[Tuple[str, Module]]` - **描述**: 返回模块及其所有子模块(递归)的名称和模块的迭代器。 - **用法**: 遍历整个模块树及其名称。 - **示例**: ```python for name, module in model.named_modules(): print(name, module) ``` - **说明**: 比 `named_children()` 更深入,包含所有层级。 --- ### 22. `named_parameters(recurse: bool = True, prefix: str = '', remove_duplicate: bool = True) -> Iterator[Tuple[str, Parameter]]` - **描述**: 返回模块中所有参数的名称和值的迭代器。 - **用法**: 访问模型的参数(如权重和偏置)。 - **示例**: ```python for name, param in model.named_parameters(): print(name, param.shape) ``` - **说明**: 如果 `recurse=True`,包括子模块的参数。 --- ### 23. `parameters(recurse: bool = True) -> Iterator[Parameter]` - **描述**: 返回模块中所有参数的迭代器。 - **用法**: 用于优化器配置。 - **示例**: ```python optimizer = torch.optim.SGD(model.parameters(), lr=0.01) ``` - **说明**: 如果 `recurse=True`,包括子模块的参数。 --- ### 24. `register_buffer(name: str, tensor: Optional[Tensor], persistent: bool = True) -> None` - **描述**: 注册一个缓冲区(如非可训练的张量)。 - **用法**: 用于存储不需要梯度的张量(如 BatchNorm 的 `running_mean`)。 - **示例**: ```python self.register_buffer("running_mean", torch.zeros(10)) ``` - **说明**: `persistent=True` 表示缓冲区会保存到 `state_dict`。 --- ### 25. `register_forward_hook(hook: Callable[..., None], *, prepend: bool = False, with_kwargs: bool = False) -> RemovableHandle` - **描述**: 注册一个前向传播钩子函数,在模块前向传播时调用。 - **用法**: 用于监控或修改前向传播的输入/输出。 - **示例**: ```python def hook(module, input, output): print(output) handle = model.register_forward_hook(hook) ``` - **说明**: 返回 `RemovableHandle`,可通过 `handle.remove()` 移除钩子。 --- ### 26. `register_forward_pre_hook(hook: Callable[..., None], *, prepend: bool = False, with_kwargs: bool = False) -> RemovableHandle` - **描述**: 注册一个前向传播前的钩子函数,在 `forward` 调用前触发。 - **用法**: 用于修改输入或调试。 - **示例**: ```python def pre_hook(module, input): print(input) handle = model.register_forward_pre_hook(pre_hook) ``` - **说明**: 类似 `register_forward_hook`,但在 `forward` 之前运行。 --- ### 27. `register_full_backward_hook(hook: Callable[..., None], prepend: bool = False) -> RemovableHandle` - **描述**: 注册一个反向传播钩子函数,在梯度计算时调用。 - **用法**: 用于监控或修改梯度。 - **示例**: ```python def backward_hook(module, grad_input, grad_output): print(grad_output) handle = model.register_full_backward_hook(backward_hook) ``` - **说明**: 在 PyTorch 2.0+ 中推荐使用,替代旧的 `register_backward_hook`。 --- ### 28. `register_parameter(name: str, param: Optional[Parameter]) -> None` - **描述**: 注册一个参数到模块。 - **用法**: 动态添加可训练参数。 - **示例**: ```python self.register_parameter("weight", nn.Parameter(torch.randn(10, 5))) ``` - **说明**: 参数会自动加入 `parameters()` 和 `state_dict`。 --- ### 29. `requires_grad_(requires_grad: bool = True) -> T` - **描述**: 设置模块所有参数的 `requires_grad` 属性。 - **用法**: 冻结或解冻参数。 - **示例**: ```python model.requires_grad_(False) # 冻结参数 ``` - **说明**: 常用于冻结预训练模型的部分层。 --- ### 30. `share_memory() -> T` - **描述**: 将模块的参数和缓冲区移动到共享内存。 - **用法**: 用于多进程数据共享。 - **示例**: ```python model.share_memory() ``` - **说明**: 主要用于 `torch.multiprocessing` 场景。 --- ### 31. `state_dict(*args, destination: Optional[Dict[str, Tensor]] = None, prefix: str = '', keep_vars: bool = False) -> Dict[str, Tensor]` - **描述**: 返回模块的参数和缓冲区的状态字典。 - **用法**: 保存模型状态。 - **示例**: ```python torch.save(model.state_dict(), "model.pth") ``` - **说明**: `keep_vars=True` 时保留 `Tensor` 的 `Variable` 特性(较少使用)。 --- ### 32. `to(*args, **kwargs) -> T` - **描述**: 将模块的参数和缓冲区移动到指定设备或数据类型。 - **用法**: 灵活的设备/类型转换。 - **示例**: ```python model.to(device="cuda", dtype=torch.float16) ``` - **说明**: 支持多种参数形式(如 `to(device)`, `to(dtype)`)。 --- ### 33. `to_empty(*, device: Union[str, torch.device], recurse: bool = True) -> T` - **描述**: 将模块的参数和缓冲区移动到指定设备,但不初始化内容。 - **用法**: 用于初始化空模型。 - **示例**: ```python model.to_empty(device="cuda") ``` - **说明**: 较少使用,适合特殊场景。 --- ### 34. `train(mode: bool = True) -> T` - **描述**: 设置模块的训练模式。 - **用法**: 启用/禁用训练特有的行为(如 Dropout、BatchNorm)。 - **示例**: ```python model.train() # 训练模式 model.train(False) # 等同于 eval() ``` - **说明**: 与 `eval()` 相对,用于切换训练/评估模式。 --- ### 35. `type(dst_type: Union[str, torch.dtype]) -> T` - **描述**: 将模块的参数和缓冲区转换为指定数据类型。 - **用法**: 类似 `float()`, `double()` 等,但更通用。 - **示例**: ```python model.type(torch.float16) ``` - **说明**: 支持字符串(如 `"torch.float32"`)或 `torch.dtype`。 --- ### 36. `xpu(device: Optional[Union[int, torch.device]] = None) -> T` - **描述**: 将模块移动到 XPU 设备(Intel GPU)。 - **用法**: 用于支持 XPU 的设备。 - **示例**: ```python model.xpu(0) ``` - **说明**: 仅在支持 XPU 的环境中有效。 --- ### 37. `zero_grad(set_to_none: bool = False) -> None` - **描述**: 将模块所有参数的梯度清零。 - **用法**: 在优化步骤前调用。 - **示例**: ```python model.zero_grad() ``` - **说明**: 如果 `set_to_none=True`,梯度设为 `None` 而非 0(更节省内存)。 --- ### 注意事项 - **模块管理**: 使用 `children()`, `modules()`, `named_parameters()` 等方法可以方便地管理复杂模型的结构和参数。 - **设备与类型**: 确保模型和输入数据在同一设备和数据类型(`to()`, `cuda()`, `float()` 等)。 - **钩子函数**: 钩子(如 `register_forward_hook`)是调试和动态修改模型的强大工具。 - **状态保存**: `state_dict()` 和 `load_state_dict()` 是保存和加载模型的关键方法。 ### 实践建议 你可以尝试以下代码来探索 `torch.nn.Module` 的功能: ```python import torch import torch.nn as nn class MyModel(nn.Module): def __init__(self): super().__init__() self.fc1 = nn.Linear(10, 5) self.register_buffer("running_mean", torch.zeros(5)) def forward(self, x): return self.fc1(x) model = MyModel() print(list(model.named_parameters())) # 查看参数 print(list(model.named_buffers())) # 查看缓冲区 model.to("cuda") # 移动到 GPU model.eval() # 设置评估模式 torch.save(model.state_dict(), "model.pth") # 保存模型 ``` 如果你想深入某个方法(如钩子函数或 `state_dict`)或需要具体代码示例,请告诉我,我可以进一步展开!